[LINALG] Add E2E support for `aten.where.[Scalar|ScalarSelf|ScalarOther]` ops

This commit decomposes the scalar variants of the `aten.where.*` op into
the `aten.where.self` op. It covers the `aten.where.Scalar`,
`aten.where.ScalarSelf`, and `aten.where.ScalarOther` ops.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/719/head
Gaurav Shukla 2022-03-11 22:51:36 +05:30
parent 2597c481f6
commit 969785d1b6
8 changed files with 303 additions and 0 deletions
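
For context, the three new overloads differ from `aten.where.self` only in which of the value operands are Python scalars rather than tensors. A minimal PyTorch illustration of which overload each call maps to (plain eager PyTorch, not part of the patch):

    import torch

    cond = torch.rand(3, 4, 5) > 0.5
    t = torch.rand(4, 5, dtype=torch.float64)

    torch.where(cond, t, t)      # aten::where.self        (Tensor, Tensor, Tensor)
    torch.where(cond, 4.0, 8.0)  # aten::where.Scalar      (Tensor, Scalar, Scalar)
    torch.where(cond, 4.0, t)    # aten::where.ScalarSelf  (Tensor, Scalar, Tensor)
    torch.where(cond, t, 8.0)    # aten::where.ScalarOther (Tensor, Tensor, Scalar)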


@@ -4972,6 +4972,81 @@ def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
  }];
}

def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$condition,
    AnyTorchScalarType:$self,
    AnyTorchScalarType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenWhereScalarOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenWhereScalarOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$condition,
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenWhereScalarOtherOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenWhereScalarOtherOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$condition,
    AnyTorchScalarType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenWhereScalarSelfOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenWhereScalarSelfOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
    AllowsTypeRefinement,
    ReadOnly


@@ -696,6 +696,60 @@ public:
};
} // namespace

// Decompose aten.where.Scalar into aten.where.self op.
namespace {
class DecomposeAtenWhereScalarOp : public OpRewritePattern<AtenWhereScalarOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenWhereScalarOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto resType = op.getType().cast<BaseTensorType>();
    Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.self());
    Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.other());
    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.condition(),
                                                 selfTensor, otherTensor);
    return success();
  }
};
} // namespace

// Decompose aten.where.ScalarOther into aten.where.self op.
namespace {
class DecomposeAtenWhereScalarOtherOp
    : public OpRewritePattern<AtenWhereScalarOtherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto resType = op.getType().cast<BaseTensorType>();
    Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.other());
    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.condition(),
                                                 op.self(), otherTensor);
    return success();
  }
};
} // namespace

// Decompose aten.where.ScalarSelf into aten.where.self op.
namespace {
class DecomposeAtenWhereScalarSelfOp
    : public OpRewritePattern<AtenWhereScalarSelfOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto resType = op.getType().cast<BaseTensorType>();
    Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.self());
    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.condition(),
                                                 selfTensor, op.other());
    return success();
  }
};
} // namespace

// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
namespace {
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
@@ -1591,6 +1645,12 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenZerosLikeOp>();
    patterns.add<DecomposeAtenExpandOp>(context);
    target.addIllegalOp<AtenExpandOp>();
    patterns.add<DecomposeAtenWhereScalarOp>(context);
    target.addIllegalOp<AtenWhereScalarOp>();
    patterns.add<DecomposeAtenWhereScalarOtherOp>(context);
    target.addIllegalOp<AtenWhereScalarOtherOp>();
    patterns.add<DecomposeAtenWhereScalarSelfOp>(context);
    target.addIllegalOp<AtenWhereScalarSelfOp>();
    patterns.add<DecomposeAtenSizeOp>(context);
    target.addIllegalOp<AtenSizeOp>();
    patterns.add<DecomposeAtenReshapeOp>(context);
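
Each pattern above wraps the scalar operand(s) into a rank-0 tensor of the result's element type (via `createRank0Tensor`) and then reuses `aten.where.self`, which broadcasts the rank-0 value against the condition. A quick eager-mode check of that equivalence (illustrative only, not part of the patch):

    import torch

    cond = torch.rand(3, 4, 5) > 0.5
    b = torch.rand(4, 5, dtype=torch.float64)

    # where.ScalarOther is the same as where.self with `other` wrapped as a rank-0 tensor.
    ref = torch.where(cond, b, 8.0)
    dec = torch.where(cond, b, torch.tensor(8.0, dtype=torch.float64))
    assert torch.equal(ref, dec)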


@@ -605,6 +605,37 @@ ChangeResult TypeAnalyzer::visitOperation(
    return incorporateKnowledge(op->getResult(0), knowledge);
  }

  // Promote 2nd and 3rd operands.
  if (isa<AtenWhereScalarOp>(op)) {
    Value lhsScalar = op->getOperand(1);
    Value rhsScalar = op->getOperand(2);
    auto knowledge =
        ValueKnowledge::getNotNonePessimisticValueState(getContext());
    knowledge.dtype =
        getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()});
    return incorporateKnowledge(op->getResult(0), knowledge);
  }

  // Promote 2nd and 3rd operands.
  if (isa<AtenWhereScalarOtherOp>(op)) {
    auto lhs = operands[1]->getValue();
    Value scalar = op->getOperand(2);
    auto knowledge =
        ValueKnowledge::getNotNonePessimisticValueState(getContext());
    knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
    return incorporateKnowledge(op->getResult(0), knowledge);
  }

  // Promote 2nd and 3rd operands.
  if (isa<AtenWhereScalarSelfOp>(op)) {
    auto rhs = operands[2]->getValue();
    Value scalar = op->getOperand(1);
    auto knowledge =
        ValueKnowledge::getNotNonePessimisticValueState(getContext());
    knowledge.dtype = getPromotedResultType(&rhs, scalar.getType());
    return incorporateKnowledge(op->getResult(0), knowledge);
  }

  // 2 results take dtype from first operand.
  if (isa<AtenNllLossForwardOp>(op)) {
    auto self = operands[0]->getValue();
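
The dtype refinement above follows PyTorch's promotion rules for mixed tensor/scalar operands: a tensor operand's dtype is kept against a wrapped Python scalar, and two Python floats fall back to the default dtype. A small sanity check against eager PyTorch (assumes the default dtype is float32):

    import torch

    cond = torch.rand(3, 4) > 0.5
    t64 = torch.rand(3, 4, dtype=torch.float64)

    assert torch.where(cond, t64, 8.0).dtype == torch.float64  # ScalarOther: tensor dtype wins
    assert torch.where(cond, 4.0, t64).dtype == torch.float64  # ScalarSelf: tensor dtype wins
    assert torch.where(cond, 4.0, 8.0).dtype == torch.float32  # Scalar: default dtype
    assert torch.result_type(t64, 8.0) == torch.float64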


@@ -2051,6 +2051,18 @@ module {
    %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    return %1 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.where.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.where.ScalarOther"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.where.ScalarSelf"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.list<int>) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.lerp.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>


@@ -727,6 +727,15 @@ def aten〇_shape_as_tensor(self: List[int]) -> List[int]:
def aten〇where〇self(condition: List[int], self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_helpers.broadcast(condition, upstream_shape_helpers.broadcast(self, other))

def aten〇where〇Scalar(condition: List[int], self: float, other: float) -> List[int]:
    return upstream_shape_helpers.unary(condition)

def aten〇where〇ScalarOther(condition: List[int], self: List[int], other: float) -> List[int]:
    return upstream_shape_helpers.broadcast(condition, self)

def aten〇where〇ScalarSelf(condition: List[int], self: float, other: List[int]) -> List[int]:
    return upstream_shape_helpers.broadcast(condition, other)

def aten〇lerp〇Tensor(self: List[int], end: List[int], weight: List[int]) -> List[int]:
    return upstream_shape_helpers.broadcast(self, upstream_shape_helpers.broadcast(end, weight))
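
The new `where.*` shape functions simply mirror broadcasting against the condition: a scalar operand contributes no shape, so `where.Scalar` is unary in the condition while the ScalarSelf/ScalarOther variants broadcast the condition with the remaining tensor operand. For example (illustrative, eager PyTorch):

    import torch

    cond = torch.rand(3, 4, 5) > 0.5
    b = torch.rand(4, 5, dtype=torch.float64)

    assert torch.where(cond, b, 8.0).shape == (3, 4, 5)    # broadcast(condition, self)
    assert torch.where(cond, 4.0, 8.0).shape == (3, 4, 5)  # unary(condition)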


@@ -414,6 +414,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
    emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
    emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
    emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
    emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)")
    emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)")
    emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
    emit("aten::len.Tensor : (Tensor) -> (int)")
    emit("aten::cpu : (Tensor) -> (Tensor)")


@@ -139,6 +139,65 @@ def ElementwiseWhereSelfModule_basic(module, tu: TestUtils):

# ==============================================================================


class ElementwiseWhereScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.where(a > 0.5, 4.0, 8.0)


@register_test_case(module_factory=lambda: ElementwiseWhereScalarModule())
def ElementwiseWhereScalarModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))


# ==============================================================================


class ElementwiseWhereScalarOtherModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.where(a > 0.5, b, 8.0)


@register_test_case(module_factory=lambda: ElementwiseWhereScalarOtherModule())
def ElementwiseWhereScalarOtherModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())


# ==============================================================================


class ElementwiseWhereScalarSelfModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.where(a > 0.5, 4.0, b)


@register_test_case(module_factory=lambda: ElementwiseWhereScalarSelfModule())
def ElementwiseWhereScalarSelfModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())


# ==============================================================================

# Addition is an interesting special case of a binary op, because under the hood
# it carries a third scalar "alpha" parameter, which needs special handling.
class ElementwiseAddModule(torch.nn.Module):


@@ -800,3 +800,57 @@ func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
  %1 = torch.aten.new_empty %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
  return %1 : !torch.vtensor<[2,3],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.where.Scalar(
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[CST8:.*]] = torch.constant.float 8.000000e+00
// CHECK: %[[CST4:.*]] = torch.constant.float 4.000000e+00
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[LIST]], %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[FILL_SELF:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST4]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE2:.*]] = torch.constant.none
// CHECK: %[[ALLOC2:.*]] = torch.aten.empty.memory_format %[[LIST2]], %none_0, %none_0, %none_0, %none_0, %none_0 : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[FILL_OTHER:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC2]], %[[CST8]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL_SELF]], %[[FILL_OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
  %cst8 = torch.constant.float 8.000000e+00
  %cst4 = torch.constant.float 4.000000e+00
  %0 = torch.aten.where.Scalar %arg0, %cst4, %cst8 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.float -> !torch.vtensor<[?,?,?],f32>
  return %0 : !torch.vtensor<[?,?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.where.ScalarSelf(
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[OTHER:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[LIST]], %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL]], %[[OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
func @torch.aten.where.ScalarSelf(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
  %cst = torch.constant.float 4.000000e+00
  %0 = torch.aten.where.ScalarSelf %arg0, %cst, %arg1 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
  return %0 : !torch.vtensor<[?,?,?],f64>
}
// -----
// CHECK-LABEL: func @torch.aten.where.ScalarOther(
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[SELF:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[LIST]], %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[SELF]], %[[FILL]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],f64>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
  %cst = torch.constant.float 4.000000e+00
  %0 = torch.aten.where.ScalarOther %arg0, %arg1, %cst : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64>
  return %0 : !torch.vtensor<[?,?,?],f64>
}