Implement lowering of torch.aten.lerp.Scalar (#2773)

Closes nod-ai/SHARK-Turbine#356
pull/2847/head
Ilija Kalinić 2024-01-31 18:39:38 +01:00 committed by GitHub
parent 7301aa80fd
commit 54ef18c556
8 changed files with 165 additions and 0 deletions


@@ -1620,6 +1620,55 @@ def Torch_AtenLerp_TensorOp : Torch_Op<"aten.lerp_.Tensor", [
  }];
}
def Torch_AtenLerpScalarOp : Torch_Op<"aten.lerp.Scalar", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$end,
    AnyTorchScalarType:$weight
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenLerpScalarOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenLerpScalarOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
def Torch_AtenLerp_ScalarOp : Torch_Op<"aten.lerp_.Scalar", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::lerp_.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self,
    Torch_NonValueTensorType:$end,
    AnyTorchScalarType:$weight
  );
  let results = (outs
    Torch_NonValueTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenLerp_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenLerp_ScalarOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
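
For orientation, `aten::lerp.Scalar` is the overload where the interpolation weight is a Python scalar rather than a tensor; it computes self + weight * (end - self). A minimal eager-mode sketch of the overload being registered here (the concrete values are illustrative, not from the patch):

import torch

a = torch.tensor([0.0, 1.0, 2.0])
b = torch.tensor([10.0, 11.0, 12.0])

# aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)
out = torch.ops.aten.lerp.Scalar(a, b, 0.5)
print(out)  # tensor([5., 6., 7.])
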
def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,


@@ -8438,6 +8438,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n" " return %1 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.lerp.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.addcmul\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.addcmul\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
@@ -11198,6 +11202,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n" " return %5 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lerp.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n" " %str = torch.constant.str \"AssertionError: \"\n"


@@ -1895,6 +1895,35 @@ public:
};
} // namespace
namespace {
class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLerpScalarOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto resType = op.getType().cast<BaseTensorType>();
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }
    Value cstOne =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
    auto start = op.getSelf();
    auto inputType = start.getType().cast<BaseTensorType>();

    auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
                                                  start, cstOne);
    auto weightedDelta =
        rewriter.create<AtenMulScalarOp>(loc, inputType, delta, op.getWeight());
    auto lerp = rewriter.create<AtenAddTensorOp>(loc, inputType, start,
                                                 weightedDelta, cstOne);
    rewriter.replaceOp(op, lerp);
    return success();
  }
};
} // namespace
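
The pattern above rewrites `aten.lerp.Scalar` as start + weight * (end - start), built from `aten.sub.Tensor`, `aten.mul.Scalar`, and `aten.add.Tensor` with alpha = 1. A quick eager-mode check of that algebra (a sketch with made-up inputs, not part of the patch):

import torch

def lerp_scalar_decomposed(start, end, weight):
    # Mirrors the rewrite: delta = end - start; out = start + weight * delta
    delta = torch.sub(end, start)              # AtenSubTensorOp, alpha = 1
    weighted_delta = torch.mul(delta, weight)  # AtenMulScalarOp
    return torch.add(start, weighted_delta)    # AtenAddTensorOp, alpha = 1

a = torch.randn(5, 3)
b = torch.randn(5, 3)
assert torch.allclose(lerp_scalar_decomposed(a, b, 0.25), torch.lerp(a, b, 0.25))
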
// Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1)
namespace {
class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
@@ -6763,6 +6792,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLerpScalarOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);


@@ -488,6 +488,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenNarrowTensorOp>();
  target.addIllegalOp<Aten_EmbeddingBagOp>();
  target.addIllegalOp<AtenLiftFreshCopyOp>();
  target.addIllegalOp<AtenLerpScalarOp>();
  target.addIllegalOp<AtenIndexTensorOp>();
  target.addIllegalOp<AtenMseLossOp>();
  target.addIllegalOp<AtenRandintLowOp>();


@@ -1116,6 +1116,8 @@ TOSA_PASS_SET = {
    "ElementwiseLeakyReluModule_basic",
    "ElementwiseLeakyReluModule_basic",
    "ElementwiseLeakyReluStaticModule_basic",
    "ElementwiseLerpScalarIntModule_basic",
    "ElementwiseLerpScalarFloatModule_basic",
    "ElementwiseLog2Module_basic",
    "ElementwiseLogModule_basic",
    "ElementwiseLtDiffWidthScalarModule_basic",
@@ -1496,6 +1498,8 @@ LTC_XFAIL_SET = {
    "ElementwiseLogitModule_basic",
    "ElementwiseRemainderScalarModule_Int_Float_basic",
    "ElementwiseRemainderScalarModule_Bool_basic",
    "ElementwiseLerpScalarIntModule_basic",
    "ElementwiseLerpScalarFloatModule_basic",
    "AtenIntTensorByteDtypeModule_basic",
    "AtenIntTensorCharDtypeModule_basic",
    "UpSampleNearest2dBackwardVec_basic",


@@ -1245,6 +1245,9 @@ def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posi
def aten〇lerp〇Tensor〡shape(self: List[int], end: List[int], weight: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(end, weight))

def aten〇lerp〇Scalar〡shape(self: List[int], end: List[int], weight: float) -> List[int]:
    return upstream_shape_functions.broadcast(self, end)

def aten〇addcmul〡shape(self: List[int], tensor1: List[int], tensor2: List[int], value: float = 1) -> List[int]:
    return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(tensor1, tensor2))
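
The new shape function returns the broadcast of `self` and `end`; the scalar weight contributes nothing to the shape. A small illustration of the rule, assuming standard broadcasting semantics (these shapes are illustrative, not from the patch):

import torch

# Broadcasting self (5, 1, 3) with end (4, 3) gives (5, 4, 3);
# this mirrors upstream_shape_functions.broadcast(self, end).
print(torch.broadcast_shapes((5, 1, 3), (4, 3)))  # torch.Size([5, 4, 3])

out = torch.lerp(torch.zeros(5, 1, 3), torch.ones(4, 3), 0.5)
print(out.shape)  # torch.Size([5, 4, 3])
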
@@ -3313,6 +3316,27 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp
    dtypes = [self_dtype, end_dtype, weight_dtype]
    return promote_dtypes(ranks, dtypes)

@check_dtype_function(
    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5) +
    # Different width
    [Invocation(TensorOfShape(4, 3, dtype=torch.float32),
                TensorOfShape(4, 3, dtype=torch.float64),
                weight=0.5),
     # Different type
     Invocation(TensorOfShape(4, 3, dtype=torch.int32),
                TensorOfShape(4, 3, dtype=torch.float32),
                weight=0.5),
     Invocation(TensorOfShape(4, 3, dtype=torch.float32),
                TensorOfShape(4, 3, dtype=torch.float32),
                weight=2)])
def aten〇lerp〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight: Union[int, float, complex]) -> int:
    self_rank, self_dtype = self_rank_dtype
    end_rank, end_dtype = end_rank_dtype
    ranks: List[Optional[int]] = [self_rank, end_rank, None]
    dtypes = [self_dtype, end_dtype, get_dtype_of_scalar(weight)]
    return promote_dtypes(ranks, dtypes)
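
The new dtype function promotes the two tensor dtypes together with the dtype of the scalar weight, matching eager-mode type promotion. An illustrative check (the inputs here are chosen for the example, not taken from the patch):

import torch

b = torch.ones(2, 2, dtype=torch.float32)
c = torch.ones(2, 2, dtype=torch.float64)

# Mixed-width tensor inputs promote: float32 lerp float64 -> float64
print(torch.lerp(b, c, 0.5).dtype)  # torch.float64

# The scalar weight participates in promotion as well; result_type shows
# the same rule the dtype function encodes for (self, weight).
print(torch.result_type(b, 0.5))    # torch.float32
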

@check_dtype_function(
    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) +
    # Different width


@@ -290,6 +290,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
        "aten::logical_xor : (Tensor, Tensor) -> (Tensor)",
        "aten::logical_not : (Tensor) -> (Tensor)",
        "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
        "aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
        "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::ge.Tensor : (Tensor, Tensor) -> (Tensor)",


@@ -545,6 +545,48 @@ class ElementwiseLeakyReluStaticModule(torch.nn.Module):
def ElementwiseLeakyReluStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6, low=-1))
# ==============================================================================
class ElementwiseLerpScalarIntModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.ops.aten.lerp(a, b, weight=2)

@register_test_case(module_factory=lambda: ElementwiseLerpScalarIntModule())
def ElementwiseLerpScalarIntModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 3), tu.rand(5, 3))


class ElementwiseLerpScalarFloatModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.ops.aten.lerp(a, b, weight=0.5)

@register_test_case(module_factory=lambda: ElementwiseLerpScalarFloatModule())
def ElementwiseLerpScalarFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 3), tu.rand(5, 3))
# ==============================================================================