mirror of https://github.com/llvm/torch-mlir
[LINALG] Add E2E support for `aten.where.[Scalar|ScalarSelf|ScalarOther]` ops
This commit decomposes different variants of `aten.where.*` op into `aten.where.Self` op. It covers `aten.where.Scalar`, `aten.where.ScalarSelf` and `aten.where.ScalarOther` ops. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/719/head
parent
2597c481f6
commit
969785d1b6
|
@ -4972,6 +4972,81 @@ def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$condition,
|
||||
AnyTorchScalarType:$self,
|
||||
AnyTorchScalarType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenWhereScalarOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenWhereScalarOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$condition,
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenWhereScalarOtherOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenWhereScalarOtherOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$condition,
|
||||
AnyTorchScalarType:$self,
|
||||
AnyTorchTensorType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenWhereScalarSelfOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenWhereScalarSelfOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
|
|
@ -696,6 +696,60 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.where.Scalar into aten.where.self op.
|
||||
namespace {
|
||||
class DecomposeAtenWhereScalarOp : public OpRewritePattern<AtenWhereScalarOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenWhereScalarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.self());
|
||||
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.other());
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.condition(),
|
||||
selfTensor, otherTensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.where.ScalarOther into aten.where.self op.
|
||||
namespace {
|
||||
class DecomposeAtenWhereScalarOtherOp
|
||||
: public OpRewritePattern<AtenWhereScalarOtherOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.other());
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.condition(),
|
||||
op.self(), otherTensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.where.ScalarSelf into aten.where.self op.
|
||||
namespace {
|
||||
class DecomposeAtenWhereScalarSelfOp
|
||||
: public OpRewritePattern<AtenWhereScalarSelfOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.self());
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.condition(),
|
||||
selfTensor, op.other());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
|
||||
namespace {
|
||||
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
|
||||
|
@ -1591,6 +1645,12 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenZerosLikeOp>();
|
||||
patterns.add<DecomposeAtenExpandOp>(context);
|
||||
target.addIllegalOp<AtenExpandOp>();
|
||||
patterns.add<DecomposeAtenWhereScalarOp>(context);
|
||||
target.addIllegalOp<AtenWhereScalarOp>();
|
||||
patterns.add<DecomposeAtenWhereScalarOtherOp>(context);
|
||||
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
||||
patterns.add<DecomposeAtenWhereScalarSelfOp>(context);
|
||||
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
||||
patterns.add<DecomposeAtenSizeOp>(context);
|
||||
target.addIllegalOp<AtenSizeOp>();
|
||||
patterns.add<DecomposeAtenReshapeOp>(context);
|
||||
|
|
|
@ -605,6 +605,37 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
||||
// Promote 2nd and 3rd operands.
|
||||
if (isa<AtenWhereScalarOp>(op)) {
|
||||
Value lhsScalar = op->getOperand(1);
|
||||
Value rhsScalar = op->getOperand(2);
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(getContext());
|
||||
knowledge.dtype =
|
||||
getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()});
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
||||
// Promote 2nd and 3rd operands.
|
||||
if (isa<AtenWhereScalarOtherOp>(op)) {
|
||||
auto lhs = operands[1]->getValue();
|
||||
Value scalar = op->getOperand(2);
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(getContext());
|
||||
knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
||||
// Promote 2nd and 3rd operands.
|
||||
if (isa<AtenWhereScalarSelfOp>(op)) {
|
||||
auto rhs = operands[2]->getValue();
|
||||
Value scalar = op->getOperand(1);
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(getContext());
|
||||
knowledge.dtype = getPromotedResultType(&rhs, scalar.getType());
|
||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
||||
// 2 results take dtype from first operand.
|
||||
if (isa<AtenNllLossForwardOp>(op)) {
|
||||
auto self = operands[0]->getValue();
|
||||
|
|
|
@ -2051,6 +2051,18 @@ module {
|
|||
%1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %1 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.where.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.where.ScalarOther"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.where.ScalarSelf"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.lerp.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
%1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
|
|
|
@ -727,6 +727,15 @@ def aten〇_shape_as_tensor(self: List[int]) -> List[int]:
|
|||
def aten〇where〇self(condition: List[int], self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.broadcast(condition, upstream_shape_helpers.broadcast(self, other))
|
||||
|
||||
def aten〇where〇Scalar(condition: List[int], self: float, other: float) -> List[int]:
|
||||
return upstream_shape_helpers.unary(condition)
|
||||
|
||||
def aten〇where〇ScalarOther(condition: List[int], self: List[int], other: float) -> List[int]:
|
||||
return upstream_shape_helpers.broadcast(condition, self)
|
||||
|
||||
def aten〇where〇ScalarSelf(condition: List[int], self: float, other: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.broadcast(condition, other)
|
||||
|
||||
def aten〇lerp〇Tensor(self: List[int], end: List[int], weight: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.broadcast(self, upstream_shape_helpers.broadcast(end, weight))
|
||||
|
||||
|
|
|
@ -414,6 +414,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)")
|
||||
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
|
||||
emit("aten::len.Tensor : (Tensor) -> (int)")
|
||||
emit("aten::cpu : (Tensor) -> (Tensor)")
|
||||
|
|
|
@ -139,6 +139,65 @@ def ElementwiseWhereSelfModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseWhereScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.where(a > 0.5, 4.0, 8.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseWhereScalarModule())
|
||||
def ElementwiseWhereScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseWhereScalarOtherModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.where(a > 0.5, b, 8.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseWhereScalarOtherModule())
|
||||
def ElementwiseWhereScalarOtherModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseWhereScalarSelfModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.where(a > 0.5, 4.0, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseWhereScalarSelfModule())
|
||||
def ElementwiseWhereScalarSelfModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
# Addition is an interesting special case of a binary op, because under the hood
|
||||
# it carries a third scalar "alpha" parameter, which needs special handling.
|
||||
class ElementwiseAddModule(torch.nn.Module):
|
||||
|
|
|
@ -800,3 +800,57 @@ func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
|
|||
%1 = torch.aten.new_empty %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
|
||||
return %1 : !torch.vtensor<[2,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.where.Scalar(
|
||||
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
// CHECK: %[[CST8:.*]] = torch.constant.float 8.000000e+00
|
||||
// CHECK: %[[CST4:.*]] = torch.constant.float 4.000000e+00
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[LIST]], %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[FILL_SELF:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST4]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[NONE2:.*]] = torch.constant.none
|
||||
// CHECK: %[[ALLOC2:.*]] = torch.aten.empty.memory_format %[[LIST2]], %none_0, %none_0, %none_0, %none_0, %none_0 : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[FILL_OTHER:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC2]], %[[CST8]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL_SELF]], %[[FILL_OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f32>
|
||||
func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
%cst8 = torch.constant.float 8.000000e+00
|
||||
%cst4 = torch.constant.float 4.000000e+00
|
||||
%0 = torch.aten.where.Scalar %arg0, %cst4, %cst8 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.float -> !torch.vtensor<[?,?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.where.ScalarSelf(
|
||||
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[OTHER:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
|
||||
// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[LIST]], %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL]], %[[OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
|
||||
func @torch.aten.where.ScalarSelf(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
|
||||
%cst = torch.constant.float 4.000000e+00
|
||||
%0 = torch.aten.where.ScalarSelf %arg0, %cst, %arg1 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
|
||||
return %0 : !torch.vtensor<[?,?,?],f64>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.where.ScalarOther(
|
||||
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[SELF:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
|
||||
// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[LIST]], %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[SELF]], %[[FILL]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],f64>
|
||||
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
|
||||
func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
|
||||
%cst = torch.constant.float 4.000000e+00
|
||||
%0 = torch.aten.where.ScalarOther %arg0, %arg1, %cst : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64>
|
||||
return %0 : !torch.vtensor<[?,?,?],f64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue