[Torch Dialect] Add support for AtenScalarTensorOp (#2085)

* add scalar_tensor op

* add dynamo pass test; needs PR2062

* try to fix

* Empty commit, trigger test

* Empty commit, trigger test

* address comments

* use dtype function

* fix decompose rule

* remove unused include

* Empty commit, trigger test

* fix test

* disable ltc

* fix dtype

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
pull/2189/head snapshot-20230601.856
Zhekun Zhang 2023-05-31 20:38:50 -07:00 committed by GitHub
parent 7ab16d38cf
commit 8af3e50662
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 183 additions and 6 deletions

View File

@ -30,6 +30,7 @@ blacklist:
- arange.start - arange.start
- arange.start_step - arange.start_step
- fill.Scalar - fill.Scalar
- scalar_tensor
# Disabled in favour of functionalized alternatives # Disabled in favour of functionalized alternatives
- _reshape_alias - _reshape_alias

View File

@ -50,12 +50,6 @@ TORCHDYNAMO_XFAIL_SET = {
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor) # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
# See also: https://github.com/pytorch/torchdynamo/issues/327 # See also: https://github.com/pytorch/torchdynamo/issues/327
"AtenEmbeddingBagSumExample_basic", "AtenEmbeddingBagSumExample_basic",
# %1 = torch.operator "aten.scalar_tensor"(%float8.000000e00, %int6, %int0, %cpu, %none) : (!torch.float, !torch.int, !torch.int, !torch.Device, !torch.none) -> !torch.tensor
"ElementwiseWhereScalarModule_basic",
"ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarSelfModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
# error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal # error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
"BernoulliFloatModule_basic", "BernoulliFloatModule_basic",
@ -607,6 +601,10 @@ STABLEHLO_PASS_SET = {
"RsubIntModule_basic", "RsubIntModule_basic",
"RsubIntModule_noalpha_basic", "RsubIntModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic", "RsubInt0d_NumToTensor_Module_basic",
"ScalarTensorDefaultDtypeModule_basic",
"ScalarTensorFloat32Module_basic",
"ScalarTensorInt32Module_basic",
"ScalarTensorInt64Module_basic",
"SelectScattertModule_basic", "SelectScattertModule_basic",
"SelectScattertStaticModule_basic", "SelectScattertStaticModule_basic",
"SliceStaticModule_basic", "SliceStaticModule_basic",
@ -1064,6 +1062,10 @@ TOSA_PASS_SET = {
"PrimsViewOfModule_basic", "PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic", "PrimsViewOfZeroRankModule_basic",
"DetachModule_basic", "DetachModule_basic",
"ScalarTensorDefaultDtypeModule_basic",
"ScalarTensorFloat32Module_basic",
"ScalarTensorInt32Module_basic",
"ScalarTensorInt64Module_basic",
"UnbindIntListUnpack_Module_basic", "UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic", "UnbindIntGetItem_Module_basic",
"TensorsConcatStaticModule_basic", "TensorsConcatStaticModule_basic",

View File

@ -6427,6 +6427,33 @@ def Torch_AtenTensorIntOp : Torch_Op<"aten.tensor.int", [
}]; }];
} }
def Torch_AtenScalarTensorOp : Torch_Op<"aten.scalar_tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
AnyTorchScalarType:$s,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScalarTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenScalarTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [ def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -7146,6 +7146,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n" " %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scalar_tensor\"(%arg0: !torch.float, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.union<float, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._shape_as_tensor\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten._shape_as_tensor\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n" " %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n" " %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"

View File

@ -4356,6 +4356,36 @@ public:
}; };
} // namespace } // namespace
namespace {
// decompose aten.scalar_tensor to prim.NumToTensor.Scalar and
// aten.to.dtype_layout
class DecomposeAtenScalarTensor : public OpRewritePattern<AtenScalarTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenScalarTensorOp op,
PatternRewriter &rewriter) const override {
auto resultTy = op.getResult().getType().cast<BaseTensorType>();
auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType());
Value numToTensor = rewriter.create<PrimNumToTensorScalarOp>(
op.getLoc(),
resultTy.getWithSizesAndDtype(resultTy.getOptionalSizes(), scalarTy),
op.getS());
Value cstNone = rewriter.create<ConstantNoneOp>(op.getLoc());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
Value dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), resultTy.getDtype());
Value toDTypeLayout = rewriter.create<AtenToDtypeLayoutOp>(
op.getLoc(), op.getType(), numToTensor, dtype, op.getLayout(),
op.getDevice(), op.getPinMemory(), /*non_blocking=*/cstFalse,
/*copy=*/cstFalse, /*memory_format=*/cstNone);
rewriter.replaceOp(op, toDTypeLayout);
return success();
}
};
} // namespace
namespace { namespace {
// Decompose `aten.topk` op into `aten.sort` and `aten.slice.Tensor` op. // Decompose `aten.topk` op into `aten.sort` and `aten.slice.Tensor` op.
class DecomposeAtenTopkOp : public OpRewritePattern<AtenTopkOp> { class DecomposeAtenTopkOp : public OpRewritePattern<AtenTopkOp> {
@ -4607,6 +4637,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
GreedyRewriteConfig config; GreedyRewriteConfig config;

View File

@ -479,6 +479,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenCrossEntropyLossOp>(); target.addIllegalOp<AtenCrossEntropyLossOp>();
target.addIllegalOp<AtenVarMeanDimOp>(); target.addIllegalOp<AtenVarMeanDimOp>();
target.addIllegalOp<AtenTopkOp>(); target.addIllegalOp<AtenTopkOp>();
target.addIllegalOp<AtenScalarTensorOp>();
target.addIllegalOp<AtenScatterValueOp>(); target.addIllegalOp<AtenScatterValueOp>();
for (auto &opName : backendLegalOpsSet) { for (auto &opName : backendLegalOpsSet) {
target.addLegalOp( target.addLegalOp(

View File

@ -752,6 +752,16 @@ def atentensorint〡shape(t: int, dtype: Optional[int] = None, device: Opt
def atentensorbool〡shape(t: bool, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> List[int]: def atentensorbool〡shape(t: bool, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> List[int]:
return [] return []
def atenscalar_tensor〡shape(s: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return []
@check_dtype_function([Invocation(-1), Invocation(-1.0)])
def atenscalar_tensor〡dtype(s: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is not None:
return dtype
else:
return torch.float32
@check_shape_function([ @check_shape_function([
Invocation(TensorOfShape()), Invocation(TensorOfShape()),
Invocation(TensorOfShape(2, 3)), Invocation(TensorOfShape(2, 3)),

View File

@ -463,6 +463,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)") emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)") emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)") emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)") emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
emit("aten::all : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)")
emit("aten::all.bool : (bool[]) -> (bool)") emit("aten::all.bool : (bool[]) -> (bool)")

View File

@ -3839,6 +3839,94 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ScalarTensorFloat32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
scalar = torch.ops.aten.scalar_tensor(1.0, dtype=torch.float32)
return scalar
@register_test_case(module_factory=lambda: ScalarTensorFloat32Module())
def ScalarTensorFloat32Module_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ScalarTensorDefaultDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
scalar = torch.ops.aten.scalar_tensor(1.0)
return scalar
@register_test_case(module_factory=lambda: ScalarTensorDefaultDtypeModule())
def ScalarTensorDefaultDtypeModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ScalarTensorInt64Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
scalar = torch.ops.aten.scalar_tensor(1, dtype=torch.int64)
return scalar
@register_test_case(module_factory=lambda: ScalarTensorInt64Module())
def ScalarTensorInt64Module_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ScalarTensorInt32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
scalar = torch.ops.aten.scalar_tensor(1, dtype=torch.int32)
return scalar
@register_test_case(module_factory=lambda: ScalarTensorInt32Module())
def ScalarTensorInt32Module_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class AtenTopKModule(torch.nn.Module): class AtenTopKModule(torch.nn.Module):
def __init__(self): def __init__(self):