mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Add support for AtenScalarTensorOp (#2085)
* add scalar_tensor op * add dynamo pass test; needs PR2062 * try to fix * Empty commit, trigger test * Empty commit, trigger test * address comments * use dtype function * fix decompose rule * remove unused include * Empty commit, trigger test * fix test * disable ltc * fix dtype --------- Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>pull/2189/head snapshot-20230601.856
parent
7ab16d38cf
commit
8af3e50662
|
@ -30,6 +30,7 @@ blacklist:
|
|||
- arange.start
|
||||
- arange.start_step
|
||||
- fill.Scalar
|
||||
- scalar_tensor
|
||||
|
||||
# Disabled in favour of functionalized alternatives
|
||||
- _reshape_alias
|
||||
|
|
|
@ -50,12 +50,6 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
|
||||
# See also: https://github.com/pytorch/torchdynamo/issues/327
|
||||
"AtenEmbeddingBagSumExample_basic",
|
||||
# %1 = torch.operator "aten.scalar_tensor"(%float8.000000e00, %int6, %int0, %cpu, %none) : (!torch.float, !torch.int, !torch.int, !torch.Device, !torch.none) -> !torch.tensor
|
||||
"ElementwiseWhereScalarModule_basic",
|
||||
"ElementwiseWhereScalarOtherModule_basic",
|
||||
"ElementwiseWhereScalarSelfModule_basic",
|
||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
||||
"ElementwiseWhereScalarSelfStaticModule_basic",
|
||||
|
||||
# error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
|
||||
"BernoulliFloatModule_basic",
|
||||
|
@ -607,6 +601,10 @@ STABLEHLO_PASS_SET = {
|
|||
"RsubIntModule_basic",
|
||||
"RsubIntModule_noalpha_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"ScalarTensorDefaultDtypeModule_basic",
|
||||
"ScalarTensorFloat32Module_basic",
|
||||
"ScalarTensorInt32Module_basic",
|
||||
"ScalarTensorInt64Module_basic",
|
||||
"SelectScattertModule_basic",
|
||||
"SelectScattertStaticModule_basic",
|
||||
"SliceStaticModule_basic",
|
||||
|
@ -1064,6 +1062,10 @@ TOSA_PASS_SET = {
|
|||
"PrimsViewOfModule_basic",
|
||||
"PrimsViewOfZeroRankModule_basic",
|
||||
"DetachModule_basic",
|
||||
"ScalarTensorDefaultDtypeModule_basic",
|
||||
"ScalarTensorFloat32Module_basic",
|
||||
"ScalarTensorInt32Module_basic",
|
||||
"ScalarTensorInt64Module_basic",
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
"TensorsConcatStaticModule_basic",
|
||||
|
|
|
@ -6427,6 +6427,33 @@ def Torch_AtenTensorIntOp : Torch_Op<"aten.tensor.int", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScalarTensorOp : Torch_Op<"aten.scalar_tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchScalarType:$s,
|
||||
AnyTorchOptionalIntType:$dtype,
|
||||
AnyTorchOptionalIntType:$layout,
|
||||
AnyTorchOptionalDeviceType:$device,
|
||||
AnyTorchOptionalBoolType:$pin_memory
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScalarTensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||
}
|
||||
void AtenScalarTensorOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -7146,6 +7146,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.scalar_tensor\"(%arg0: !torch.float, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.union<float, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._shape_as_tensor\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
|
||||
|
|
|
@ -4356,6 +4356,36 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// decompose aten.scalar_tensor to prim.NumToTensor.Scalar and
|
||||
// aten.to.dtype_layout
|
||||
class DecomposeAtenScalarTensor : public OpRewritePattern<AtenScalarTensorOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenScalarTensorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
auto resultTy = op.getResult().getType().cast<BaseTensorType>();
|
||||
auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType());
|
||||
Value numToTensor = rewriter.create<PrimNumToTensorScalarOp>(
|
||||
op.getLoc(),
|
||||
resultTy.getWithSizesAndDtype(resultTy.getOptionalSizes(), scalarTy),
|
||||
op.getS());
|
||||
|
||||
Value cstNone = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||
Value dtype =
|
||||
getDtypeIntValueForType(rewriter, op.getLoc(), resultTy.getDtype());
|
||||
Value toDTypeLayout = rewriter.create<AtenToDtypeLayoutOp>(
|
||||
op.getLoc(), op.getType(), numToTensor, dtype, op.getLayout(),
|
||||
op.getDevice(), op.getPinMemory(), /*non_blocking=*/cstFalse,
|
||||
/*copy=*/cstFalse, /*memory_format=*/cstNone);
|
||||
rewriter.replaceOp(op, toDTypeLayout);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.topk` op into `aten.sort` and `aten.slice.Tensor` op.
|
||||
class DecomposeAtenTopkOp : public OpRewritePattern<AtenTopkOp> {
|
||||
|
@ -4607,6 +4637,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
|
|
|
@ -479,6 +479,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenCrossEntropyLossOp>();
|
||||
target.addIllegalOp<AtenVarMeanDimOp>();
|
||||
target.addIllegalOp<AtenTopkOp>();
|
||||
target.addIllegalOp<AtenScalarTensorOp>();
|
||||
target.addIllegalOp<AtenScatterValueOp>();
|
||||
for (auto &opName : backendLegalOpsSet) {
|
||||
target.addLegalOp(
|
||||
|
|
|
@ -752,6 +752,16 @@ def aten〇tensor〇int〡shape(t: int, dtype: Optional[int] = None, device: Opt
|
|||
def aten〇tensor〇bool〡shape(t: bool, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> List[int]:
|
||||
return []
|
||||
|
||||
def aten〇scalar_tensor〡shape(s: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return []
|
||||
|
||||
@check_dtype_function([Invocation(-1), Invocation(-1.0)])
|
||||
def aten〇scalar_tensor〡dtype(s: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape()),
|
||||
Invocation(TensorOfShape(2, 3)),
|
||||
|
|
|
@ -463,6 +463,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
|
||||
emit("aten::all : (Tensor) -> (Tensor)")
|
||||
emit("aten::all.bool : (bool[]) -> (bool)")
|
||||
|
|
|
@ -3839,6 +3839,94 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ScalarTensorFloat32Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
scalar = torch.ops.aten.scalar_tensor(1.0, dtype=torch.float32)
|
||||
return scalar
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ScalarTensorFloat32Module())
|
||||
def ScalarTensorFloat32Module_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ScalarTensorDefaultDtypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
scalar = torch.ops.aten.scalar_tensor(1.0)
|
||||
return scalar
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ScalarTensorDefaultDtypeModule())
|
||||
def ScalarTensorDefaultDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ScalarTensorInt64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
scalar = torch.ops.aten.scalar_tensor(1, dtype=torch.int64)
|
||||
return scalar
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ScalarTensorInt64Module())
|
||||
def ScalarTensorInt64Module_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ScalarTensorInt32Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
scalar = torch.ops.aten.scalar_tensor(1, dtype=torch.int32)
|
||||
return scalar
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ScalarTensorInt32Module())
|
||||
def ScalarTensorInt32Module_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTopKModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue