mirror of https://github.com/llvm/torch-mlir
Add min/max/clamp support.
Part of #380. Also:
- BoolType is no longer considered a Scalar
- e2e framework fixes for NaN handling
- `tu.rand(..., low=, high=)` support
- delete an unused variable (fixes a warning)
- Add IouOfModule from #380 to the e2e test suite (this is a common calculation in vision models)
parent 029c30c060
commit 30df2ec71b
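Background for the test and framework changes below (an illustrative Python sketch, not part of the patch): torch.minimum and torch.maximum propagate NaN, which is why the new e2e tests feed tu.nans(...) inputs and why the result comparison now passes equal_nan=True.

import torch

x = torch.full((2, 2), float("nan"))
y = torch.rand(2, 2)

# If either operand is NaN, minimum/maximum return NaN for that element.
assert torch.isnan(torch.minimum(x, y)).all()
assert torch.isnan(torch.maximum(x, y)).all()

# clamp accepts optional min/max bounds, as exercised by ElementwiseClampModule below.
z = torch.clamp(torch.tensor([-10.0, 0.0, 10.0]), min=-2.0, max=3.0)
assert z.tolist() == [-2.0, 0.0, 3.0]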
@@ -238,3 +238,75 @@ def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))


# ==============================================================================


class ElementwiseMinimumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.minimum(x, y)


@register_test_case(module_factory=lambda: ElementwiseMinimumModule())
def ElementwiseMinimumModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5), tu.rand(3, 5))
    module.forward(tu.nans(3, 5), tu.rand(3, 5))


# ==============================================================================


class ElementwiseMaximumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.maximum(x, y)


@register_test_case(module_factory=lambda: ElementwiseMaximumModule())
def ElementwiseMaximumModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5), tu.rand(3, 5))
    module.forward(tu.nans(3, 5), tu.rand(3, 5))


# ==============================================================================


class ElementwiseClampModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # TODO: It would be great to return all of these, so they get checked
        # individually, but RefBackend doesn't support multiple returns.
        # Instead, multiply them together, which has some chance of propagating
        # all the values.
        float_min = torch.clamp(x, min=-2.0)
        int_min = torch.clamp(x, min=-3)
        float_max = torch.clamp(x, max=2.0)
        int_max = torch.clamp(x, max=3)
        both = torch.clamp(x, min=-5, max=5)
        return float_min * int_min * float_max * int_max * both


@register_test_case(module_factory=lambda: ElementwiseClampModule())
def ElementwiseClampModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5, low=-10, high=10))
@@ -30,3 +30,30 @@ class ResNet18Module(torch.nn.Module):
@register_test_case(module_factory=lambda: ResNet18Module())
def ResNet18Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 3, 224, 224))


class IouOfModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, bbox1, bbox2):
        area1 = (bbox1[:, 2] - bbox1[:, 0]) * (bbox1[:, 3] - bbox1[:, 1])
        area2 = (bbox2[:, 2] - bbox2[:, 0]) * (bbox2[:, 3] - bbox2[:, 1])
        lt = torch.maximum(bbox1[:, :2], bbox2[:, :2])
        rb = torch.minimum(bbox1[:, 2:], bbox2[:, 2:])

        overlap_coord = (rb - lt).clip(0)
        overlap = overlap_coord[:, 0] * overlap_coord[:, 1]
        union = area1 + area2 - overlap

        return overlap / union


@register_test_case(module_factory=lambda: IouOfModule())
def IouOfModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1024, 4), tu.rand(1024, 4))
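For reference, a small worked example of the IoU computation above, with made-up box values (not part of the patch); boxes are in [x1, y1, x2, y2] form:

import torch

# Two 2x2 boxes that overlap in a 1x1 square.
bbox1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
bbox2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])

lt = torch.maximum(bbox1[:, :2], bbox2[:, :2])  # [[1., 1.]]
rb = torch.minimum(bbox1[:, 2:], bbox2[:, 2:])  # [[2., 2.]]
overlap = (rb - lt).clip(0).prod(dim=1)         # [1.]
union = 4.0 + 4.0 - overlap                     # [7.]
print(overlap / union)                          # tensor([0.1429]), i.e. 1/7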
@@ -15,6 +15,7 @@
# to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
    "QuantizedMLP_basic",
    "IouOfModule_basic",
}

REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
@@ -762,6 +762,68 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
  let assemblyFormat = "$self `,` $mask `,` $value attr-dict `:` type($self) `,` type($mask) `,` type($value) `->` type($result)";
}

def Torch_AtenClampOp : Torch_Op<"aten.clamp", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchOptionalScalarType:$min,
    AnyTorchOptionalScalarType:$max
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $min `,` $max attr-dict `:` type($self) `,` type($min) `,` type($max) `->` type($result)";
}

def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::clamp_ : (Tensor, Scalar?, Scalar?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchOptionalScalarType:$min,
    AnyTorchOptionalScalarType:$max
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $min `,` $max attr-dict `:` type($self) `,` type($min) `,` type($max) `->` type($result)";
}

def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::maximum : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::minimum : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
    AllowsTypeRefinement,
    HasValueSemantics
@@ -399,12 +399,14 @@ def AnyTorchOptionalTensorListType :
  ListOf<[AnyTorchOptionalTensorType],
         "Any optional tensor list type (Tensor?[])">;

+// Note: TorchScript does not consider !torch.bool to be a Scalar.
def AnyTorchScalarType : AnyTypeOf<[
  Torch_IntType,
  Torch_FloatType,
-  Torch_BoolType,
  Torch_NumberType,
], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">;
+def AnyTorchOptionalScalarType:
+    OptionalOf<AnyTorchScalarType, "Optional torch scalar type">;

// See function `DictTypePtr create(TypePtr key, TypePtr value)`
// in aten/src/ATen/core/jit_type.h.
@@ -423,6 +425,7 @@ def AnyTorchType : AnyTypeOf<[
  AnyTorchScalarType,
  AnyTorchTensorType,
  Torch_AnyType,
  Torch_BoolType,
  Torch_DictType,
  Torch_DeviceType,
  Torch_ListType,
@@ -1259,6 +1259,26 @@ public:
};
} // namespace

static Value promoteScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype) {
  // TODO: For the integer case, we probably need the unconverted dtype to
  // be able to know if we need signed or unsigned conversion.
  if (dtype.isa<mlir::FloatType>()) {
    if (scalar.getType().isa<mlir::FloatType>()) {
      // `scalar` will always be f64 since that is what the TypeConverter
      // converts !torch.float to.
      return b.create<arith::TruncFOp>(loc, scalar, dtype);
    } else {
      assert(scalar.getType().isa<mlir::IntegerType>());
      // `scalar` will always be i64 since that is what the TypeConverter
      // converts !torch.int to.
      return b.create<arith::SIToFPOp>(loc, scalar, dtype);
    }
  }
  mlir::emitError(loc) << "promoteScalarToDtype for dtype " << dtype;
  return nullptr;
}

static Value createLinalgPayloadCalculationForElementwiseOp(
    OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
    ArrayRef<Value> operands) {
@@ -1373,6 +1393,59 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    auto weightedDelta = b.create<arith::MulFOp>(loc, delta, weight);
    return b.create<arith::AddFOp>(loc, start, weightedDelta);
  }
  if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
    if (!minimum.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      minimum.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                         payloadArgs[0], payloadArgs[1]);
    return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
  }
  if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
    if (!maximum.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      maximum.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                         payloadArgs[0], payloadArgs[1]);
    return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
  }
  if (auto clamp = dyn_cast<AtenClampOp>(op)) {
    auto dtype = clamp.getType().cast<ValueTensorType>().getDtype();
    if (!dtype.isa<mlir::FloatType>()) {
      clamp.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    AtenClampOp::Adaptor adaptor(operands);
    auto min = adaptor.min();
    auto max = adaptor.max();
    if (min.getType().isa<Torch::OptionalType>() ||
        max.getType().isa<Torch::OptionalType>()) {
      clamp.emitError("unimplemented: runtime optional type");
      return nullptr;
    }
    auto result = payloadArgs[0];
    if (!min.getType().isa<Torch::NoneType>()) {
      auto minPromoted = promoteScalarToDtype(b, loc, min, dtype);
      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                          result, minPromoted);
      result = b.create<SelectOp>(loc, pred, minPromoted, result);
    }
    if (!max.getType().isa<Torch::NoneType>()) {
      auto maxPromoted = promoteScalarToDtype(b, loc, max, dtype);
      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          result, maxPromoted);
      result = b.create<SelectOp>(loc, pred, maxPromoted, result);
    }
    return result;
  }
  op->emitError("unimplemented lowering in "
                "createLinalgPayloadCalculationForElementwiseOp");
  return nullptr;
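As a scalar-level restatement (an illustrative Python sketch, not part of the patch), the clamp payload above lowers to two compare-and-select pairs rather than a dedicated min/max operation; per element it computes:

def clamp_scalar(x, min_val=None, max_val=None):
    # Mirrors the generated cmpf/select pairs: clamp from below, then from above.
    result = x
    if min_val is not None:
        result = min_val if result < min_val else result
    if max_val is not None:
        result = max_val if result > max_val else result
    return result

assert clamp_scalar(-10.0, min_val=-2.0, max_val=3.0) == -2.0
assert clamp_scalar(0.5, min_val=-2.0, max_val=3.0) == 0.5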
@@ -1581,7 +1654,8 @@ struct ConvertElementwiseOp : ConversionPattern {
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
             AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
-            AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp>(op))
+            AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
+            AtenMaximumOp, AtenClampOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2509,7 +2583,8 @@ public:
    patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
    target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
                        AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
-                       AtenLerpTensorOp, AtenSigmoidOp>();
+                       AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
+                       AtenMaximumOp, AtenClampOp>();
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenUnsqueezeOp>();
    patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
@@ -95,7 +95,6 @@ public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMatmulOp op,
                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
    Value lhs = op.self();
    Value rhs = op.other();

@@ -198,7 +198,7 @@ public:
                 DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
                 AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
                 AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
-                AtenLayerNormOp>(op)) {
+                AtenLayerNormOp, AtenClampOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }

@@ -252,7 +252,8 @@ public:
    } else if (auto avgPool2d = llvm::dyn_cast<AtenAdaptiveAvgPool2dOp>(op)) {
      return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
    } else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
-                  AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp>(op)) {
+                  AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
+                  AtenMinimumOp, AtenMaximumOp>(op)) {
      return visitBinaryBroadcastingOp(op, operands);
    } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
      return visitAtenLerpTensorOp(lerpTensor, operands);
@@ -222,6 +222,7 @@ TORCH_TYPE_TO_ODS_TYPE = {
    "Tensor?[]": "AnyTorchOptionalTensorListType",
    "Tensor[]": "AnyTorchTensorListType",
    "Scalar": "AnyTorchScalarType",
    "Scalar?": "AnyTorchOptionalScalarType",
    "int": "Torch_IntType",
    "int[]": "TorchIntListType",
    "int?": "TorchOptionalIntType",
@@ -460,9 +461,15 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
            "aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)",
            "aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)",
            "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
-           "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)"
+           "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
+           "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
    ]:
        emit_with_mutating_variants(key)
    # Elementwise tensor compute ops that don't have the standard mutating
    # variants.
+   emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
+   emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")

    emit("aten::gelu : (Tensor) -> (Tensor)")
@@ -148,10 +148,13 @@ class TestUtils:
        torch.manual_seed(0)

    # TODO: Add zeros/ones/etc. as convenient.
-    def rand(self, *sizes):
-        if len(sizes) == 0:
-            return torch.rand([])
-        return torch.rand(*sizes)
+    def rand(self, *sizes, low=0.0, high=1.0):
+        return torch.empty(sizes).uniform_(low, high)

+    def nans(self, *sizes):
+        vals = torch.empty(sizes)
+        vals[...] = torch.nan
+        return vals


class Test(NamedTuple):
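A minimal usage sketch of the extended helper (illustrative values, not part of the patch): tu.rand(3, 5, low=-10, high=10) now draws uniformly from [low, high) instead of the fixed [0, 1) range.

import torch

torch.manual_seed(0)
vals = torch.empty(3, 5).uniform_(-10, 10)  # equivalent to tu.rand(3, 5, low=-10, high=10)
assert vals.shape == (3, 5)
assert vals.min() >= -10 and vals.max() <= 10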
@@ -152,7 +152,7 @@ class ValueReport:
            return self._record_failure(
                f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
            )
-        if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07):
+        if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
            return self._record_failure(
                f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
            )
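A quick illustration (not from the patch) of why equal_nan=True is needed here: without it, torch.allclose reports a mismatch whenever the output contains NaN, even if the golden value has NaN in the same positions.

import torch

a = torch.tensor([float("nan"), 1.0])
b = torch.tensor([float("nan"), 1.0])
print(torch.allclose(a, b))                  # False: NaN != NaN by default
print(torch.allclose(a, b, equal_nan=True))  # True: NaNs at matching positions compare equal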