mirror of https://github.com/llvm/torch-mlir
Add decomposition of `aten.masked_fill.Tensor` op.

`aten.masked_fill.Tensor` op has been decomposed to `aten.masked_fill.Scalar` op.

pull/1211/head snapshot-20220811.561
parent d96ec64be1
commit b1a506624c
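For context, here is a minimal eager-mode PyTorch sketch of the semantics this change targets (illustrative only, not part of the patch):

import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
# The Tensor overload takes the fill value as a 0-d tensor rather than a
# Python scalar; elements of `x` where `mask` is True are replaced by it.
value = torch.tensor(-1)
out = torch.ops.aten.masked_fill(x, mask, value)
print(out)
# tensor([[-1,  2, -1],
#         [ 4, -1,  6]])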
@@ -1928,6 +1928,55 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
   }];
 }
 
+def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$mask,
+    AnyTorchTensorType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
+def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$mask,
+    AnyTorchTensorType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenClampOp : Torch_Op<"aten.clamp", [
     AllowsTypeRefinement,
     HasValueSemantics,
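The second definition above is the trailing-underscore in-place variant. A small eager-mode sketch of the functional-vs-in-place distinction these two ops model (illustrative, not part of the patch):

import torch

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, False],
                     [False, False, True]])
value = torch.tensor(7.0)

# aten::masked_fill.Tensor (HasValueSemantics/ReadOnly): returns a new tensor.
y = torch.ops.aten.masked_fill(x, mask, value)
# aten::masked_fill_.Tensor (IsTrailingUnderscoreInplaceVariant): mutates `x`.
torch.ops.aten.masked_fill_(x, mask, value)
assert torch.equal(x, y)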
@@ -884,9 +884,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                                                 threshold);
     return b.create<arith::SelectOp>(loc, predicate, constantZero, grad);
   }
-  if (auto maskedFill = dyn_cast<AtenMaskedFillScalarOp>(op)) {
+  if (auto maskedFillScalar = dyn_cast<AtenMaskedFillScalarOp>(op)) {
     AtenMaskedFillScalarOp::Adaptor adaptor(operands);
-    Type dtype = converter->convertType(maskedFill.getType())
+    Type dtype = converter->convertType(maskedFillScalar.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
 
@@ -896,6 +896,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<arith::SelectOp>(loc, mask, fillValue, input);
   }
+  if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
+    AtenMaskedFillScalarOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(maskedFillTensor.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+
+    Value input = payloadArgs[0];
+    Value mask = payloadArgs[1];
+    Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
+    return b.create<arith::SelectOp>(loc, mask, fillValue, input);
+  }
 
   if (auto triu = dyn_cast<AtenTriuOp>(op)) {
     // Check if the rank of the input tensor is valid.
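The payload above lowers the op to a per-element arith.select over (mask, fillValue, input). Numerically this matches a torch.where; a minimal sketch of that equivalence (illustrative only; broadcasting is handled by the surrounding linalg.generic, which this sketch sidesteps by using matching shapes):

import torch

x = torch.randn(2, 3)
mask = torch.rand(2, 3) > 0.5
value = torch.tensor(0.0)

# Per element: result = value if mask else x, mirroring the
# select(mask, fillValue, input) emitted in the payload above.
reference = torch.where(mask, value, x)
assert torch.equal(torch.ops.aten.masked_fill(x, mask, value), reference)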
@@ -970,7 +981,7 @@ public:
              AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
              AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
              AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
-             AtenLogicalOrOp, AtenTriuOp>(op))
+             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1708,7 +1719,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
       AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
       AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>();
+      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp,
+      AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
@@ -658,7 +658,8 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
           AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
           AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
-          PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
+          PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp>(
+          op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
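The analysis change above makes the result's refined type follow operand 0 (`self`), as it already did for masked_fill.Scalar. A small eager check of that rule (the float64 value tensor here is only an illustration and is not part of the patch):

import torch

x = torch.zeros(2, 3, dtype=torch.float32)       # operand 0 (`self`)
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
value = torch.tensor(1.0, dtype=torch.float64)   # 0-d value of a wider dtype
out = torch.ops.aten.masked_fill(x, mask, value)
assert out.dtype == x.dtype and out.shape == x.shape   # result follows `self`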
@@ -6214,6 +6214,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.masked_fill.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.zero"(%arg0: !torch.list<int>) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }
@@ -777,6 +777,9 @@ def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Option
 def aten〇masked_fill〇Scalar(self: List[int], mask: List[int], value: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇masked_fill〇Tensor(self: List[int], mask: List[int], value: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇zero(self: List[int]) -> List[int]:
     return self
 
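Both the shape-library entry and the Python shape function above reduce to unary(self): the result shape is `self`'s shape, and the mask and value shapes are not consulted. A quick eager sanity check (illustrative only):

import torch

x = torch.zeros(4, 5)
mask = torch.zeros(4, 5, dtype=torch.bool)
value = torch.tensor(3.0)
out = torch.ops.aten.masked_fill(x, mask, value)
assert list(out.shape) == [4, 5]   # shape of `self`; `mask`/`value` shapes unused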
@@ -279,6 +279,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
         "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
         "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
+        "aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
         "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
         "aten::clamp_min : (Tensor, Scalar) -> (Tensor)",
         "aten::clamp_max : (Tensor, Scalar) -> (Tensor)",
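The registry string added above is the TorchScript signature with argument names elided. If needed, the full schema can be inspected from Python roughly as follows (the printed form in the comment is recalled from memory and may differ in detail):

import torch

# Overload accessor for the registered op; `_schema` shows the argument
# names that the registry string above omits.
print(torch.ops.aten.masked_fill.Tensor._schema)
# e.g. aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor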
@@ -342,7 +342,8 @@ class EmptyLikeMemoryFormatModule(torch.nn.Module):
         ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
-        return torch.empty_like(a, memory_format=torch.preserve_format).fill_(0)
+        return torch.empty_like(a,
+                                memory_format=torch.preserve_format).fill_(0)
 
 
 @register_test_case(module_factory=lambda: EmptyLikeMemoryFormatModule())
@@ -1421,3 +1422,25 @@ class MaskedFillScalarFloatValueModule(torch.nn.Module):
 def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(-10, 10, (2, 3)),
                    torch.randint(0, 2, (2, 3)).to(dtype=torch.bool))
+
+
+class MaskedFillTensorFloatValueModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.bool, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, x, mask, value):
+        return torch.ops.aten.masked_fill(x, mask, value=value)
+
+
+@register_test_case(module_factory=lambda: MaskedFillTensorFloatValueModule())
+def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-10, 10, (2, 3)),
+                   torch.randint(0, 2, (2, 3)).to(dtype=torch.bool), tu.rand())