Add min/max/clamp support.

Part of #380

Also:
- BoolType is no longer considered a Scalar
- e2e framework fixes for NaN handling
- `tu.rand(..., low=, high=)` support
- delete an unused variable (fixes a compiler warning)
- Add IouOfModule from #380 to the e2e test suite (IoU is a common
  calculation in vision models)
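
For context, a small illustrative PyTorch sketch (not part of this change) of the
semantics the new e2e tests exercise: `torch.minimum`/`torch.maximum` propagate NaN
from either operand, and `torch.clamp` accepts optional `min`/`max` scalars, either
of which may be omitted.

    import torch

    x = torch.tensor([1.0, float("nan"), -4.0])
    y = torch.tensor([2.0, 0.5, -1.0])

    # minimum/maximum propagate NaN, which is why the tests also feed tu.nans(...)
    torch.minimum(x, y)            # -> [1., nan, -4.]
    torch.maximum(x, y)            # -> [2., nan, -1.]

    # clamp takes optional min/max scalars; either bound may be left out
    torch.clamp(x, min=-2.0)       # -> [1., nan, -2.]
    torch.clamp(x, max=0.5)        # -> [0.5, nan, -4.]
    torch.clamp(x, min=-2, max=2)  # -> [1., nan, -2.]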

pull/363/head snapshot-20211027.48
Sean Silva 2021-10-27 03:44:01 +00:00
parent 029c30c060
commit 30df2ec71b
11 changed files with 262 additions and 12 deletions


@@ -238,3 +238,75 @@ def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))


# ==============================================================================


class ElementwiseMinimumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.minimum(x, y)


@register_test_case(module_factory=lambda: ElementwiseMinimumModule())
def ElementwiseMinimumModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5), tu.rand(3, 5))
    module.forward(tu.nans(3, 5), tu.rand(3, 5))


# ==============================================================================


class ElementwiseMaximumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.maximum(x, y)


@register_test_case(module_factory=lambda: ElementwiseMaximumModule())
def ElementwiseMaximumModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5), tu.rand(3, 5))
    module.forward(tu.nans(3, 5), tu.rand(3, 5))


# ==============================================================================


class ElementwiseClampModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # TODO: It would be great to return all of these, so they get checked
        # individually, but RefBackend doesn't support multiple returns.
        # Instead, multiply them together, which has some chance of propagating
        # all the values.
        float_min = torch.clamp(x, min=-2.0)
        int_min = torch.clamp(x, min=-3)
        float_max = torch.clamp(x, max=2.0)
        int_max = torch.clamp(x, max=3)
        both = torch.clamp(x, min=-5, max=5)
        return float_min * int_min * float_max * int_max * both


@register_test_case(module_factory=lambda: ElementwiseClampModule())
def ElementwiseClampModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5, low=-10, high=10))


@@ -30,3 +30,30 @@ class ResNet18Module(torch.nn.Module):
@register_test_case(module_factory=lambda: ResNet18Module())
def ResNet18Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 3, 224, 224))


class IouOfModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, bbox1, bbox2):
        area1 = (bbox1[:, 2] - bbox1[:, 0]) * (bbox1[:, 3] - bbox1[:, 1])
        area2 = (bbox2[:, 2] - bbox2[:, 0]) * (bbox2[:, 3] - bbox2[:, 1])

        lt = torch.maximum(bbox1[:, :2], bbox2[:, :2])
        rb = torch.minimum(bbox1[:, 2:], bbox2[:, 2:])

        overlap_coord = (rb - lt).clip(0)
        overlap = overlap_coord[:, 0] * overlap_coord[:, 1]
        union = area1 + area2 - overlap

        return overlap / union


@register_test_case(module_factory=lambda: IouOfModule())
def IouOfModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1024, 4), tu.rand(1024, 4))


@@ -15,6 +15,7 @@
# to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
    "QuantizedMLP_basic",
    "IouOfModule_basic",
}

REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS


@@ -762,6 +762,68 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
  let assemblyFormat = "$self `,` $mask `,` $value attr-dict `:` type($self) `,` type($mask) `,` type($value) `->` type($result)";
}

def Torch_AtenClampOp : Torch_Op<"aten.clamp", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchOptionalScalarType:$min,
    AnyTorchOptionalScalarType:$max
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $min `,` $max attr-dict `:` type($self) `,` type($min) `,` type($max) `->` type($result)";
}

def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::clamp_ : (Tensor, Scalar?, Scalar?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchOptionalScalarType:$min,
    AnyTorchOptionalScalarType:$max
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $min `,` $max attr-dict `:` type($self) `,` type($min) `,` type($max) `->` type($result)";
}

def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::maximum : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::minimum : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
    AllowsTypeRefinement,
    HasValueSemantics


@@ -399,12 +399,14 @@ def AnyTorchOptionalTensorListType :
      ListOf<[AnyTorchOptionalTensorType],
             "Any optional tensor list type (Tensor?[])">;

// Note: TorchScript does not consider !torch.bool to be a Scalar.
def AnyTorchScalarType : AnyTypeOf<[
  Torch_IntType,
  Torch_FloatType,
  Torch_BoolType,
  Torch_NumberType,
], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">;

def AnyTorchOptionalScalarType:
    OptionalOf<AnyTorchScalarType, "Optional torch scalar type">;

// See function `DictTypePtr create(TypePtr key, TypePtr value)`
// in aten/src/ATen/core/jit_type.h.

@@ -423,6 +425,7 @@ def AnyTorchType : AnyTypeOf<[
  AnyTorchScalarType,
  AnyTorchTensorType,
  Torch_AnyType,
  Torch_BoolType,
  Torch_DictType,
  Torch_DeviceType,
  Torch_ListType,


@@ -1259,6 +1259,26 @@ public:
};
} // namespace

static Value promoteScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype) {
  // TODO: For the integer case, we probably need the unconverted dtype to
  // be able to know if we need signed or unsigned conversion.
  if (dtype.isa<mlir::FloatType>()) {
    if (scalar.getType().isa<mlir::FloatType>()) {
      // `scalar` will always be f64 since that is what the TypeConverter
      // converts !torch.float to.
      return b.create<arith::TruncFOp>(loc, scalar, dtype);
    } else {
      assert(scalar.getType().isa<mlir::IntegerType>());
      // `scalar` will always be i64 since that is what the TypeConverter
      // converts !torch.int to.
      return b.create<arith::SIToFPOp>(loc, scalar, dtype);
    }
  }
  mlir::emitError(loc) << "promoteScalarToDtype for dtype " << dtype;
  return nullptr;
}

static Value createLinalgPayloadCalculationForElementwiseOp(
    OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
    ArrayRef<Value> operands) {

@@ -1373,6 +1393,59 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    auto weightedDelta = b.create<arith::MulFOp>(loc, delta, weight);
    return b.create<arith::AddFOp>(loc, start, weightedDelta);
  }
  if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
    if (!minimum.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      minimum.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                         payloadArgs[0], payloadArgs[1]);
    return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
  }
  if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
    if (!maximum.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      maximum.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                         payloadArgs[0], payloadArgs[1]);
    return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
  }
  if (auto clamp = dyn_cast<AtenClampOp>(op)) {
    auto dtype = clamp.getType().cast<ValueTensorType>().getDtype();
    if (!dtype.isa<mlir::FloatType>()) {
      clamp.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    AtenClampOp::Adaptor adaptor(operands);
    auto min = adaptor.min();
    auto max = adaptor.max();
    if (min.getType().isa<Torch::OptionalType>() ||
        max.getType().isa<Torch::OptionalType>()) {
      clamp.emitError("unimplemented: runtime optional type");
      return nullptr;
    }
    auto result = payloadArgs[0];
    if (!min.getType().isa<Torch::NoneType>()) {
      auto minPromoted = promoteScalarToDtype(b, loc, min, dtype);
      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                          result, minPromoted);
      result = b.create<SelectOp>(loc, pred, minPromoted, result);
    }
    if (!max.getType().isa<Torch::NoneType>()) {
      auto maxPromoted = promoteScalarToDtype(b, loc, max, dtype);
      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          result, maxPromoted);
      result = b.create<SelectOp>(loc, pred, maxPromoted, result);
    }
    return result;
  }

  op->emitError("unimplemented lowering in "
                "createLinalgPayloadCalculationForElementwiseOp");
  return nullptr;

@@ -1581,7 +1654,8 @@ struct ConvertElementwiseOp : ConversionPattern {
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
             AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp>(op))
             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
             AtenMaximumOp, AtenClampOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))

@@ -2509,7 +2583,8 @@ public:
    patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
    target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
                        AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
                        AtenLerpTensorOp, AtenSigmoidOp>();
                        AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
                        AtenMaximumOp, AtenClampOp>();
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenUnsqueezeOp>();
    patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);


@@ -95,7 +95,6 @@ public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMatmulOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.self();
    Value rhs = op.other();


@@ -198,7 +198,7 @@ public:
            DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
            AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
            AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
            AtenLayerNormOp>(op)) {
            AtenLayerNormOp, AtenClampOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }

@@ -252,7 +252,8 @@ public:
    } else if (auto avgPool2d = llvm::dyn_cast<AtenAdaptiveAvgPool2dOp>(op)) {
      return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
    } else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
                   AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp>(op)) {
                   AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
                   AtenMinimumOp, AtenMaximumOp>(op)) {
      return visitBinaryBroadcastingOp(op, operands);
    } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
      return visitAtenLerpTensorOp(lerpTensor, operands);


@@ -222,6 +222,7 @@ TORCH_TYPE_TO_ODS_TYPE = {
    "Tensor?[]": "AnyTorchOptionalTensorListType",
    "Tensor[]": "AnyTorchTensorListType",
    "Scalar": "AnyTorchScalarType",
    "Scalar?": "AnyTorchOptionalScalarType",
    "int": "Torch_IntType",
    "int[]": "TorchIntListType",
    "int?": "TorchOptionalIntType",

@@ -460,9 +461,15 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
            "aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)",
            "aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)",
            "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
            "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)"
            "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
            "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
    ]:
        emit_with_mutating_variants(key)

    # Elementwise tensor compute ops that don't have the standard mutating
    # variants.
    emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::gelu : (Tensor) -> (Tensor)")


@@ -148,10 +148,13 @@ class TestUtils:
        torch.manual_seed(0)

    # TODO: Add zeros/ones/etc. as convenient.
    def rand(self, *sizes):
        if len(sizes) == 0:
            return torch.rand([])
        return torch.rand(*sizes)
    def rand(self, *sizes, low=0.0, high=1.0):
        return torch.empty(sizes).uniform_(low, high)

    def nans(self, *sizes):
        vals = torch.empty(sizes)
        vals[...] = torch.nan
        return vals

class Test(NamedTuple):


@@ -152,7 +152,7 @@ class ValueReport:
            return self._record_failure(
                f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
            )
        if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07):
        if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
            return self._record_failure(
                f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
            )
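
Note (illustrative, not part of the diff): `torch.allclose` treats NaN as unequal to
NaN unless `equal_nan=True` is passed, so without this flag the new minimum/maximum
tests, which intentionally produce NaN outputs, would report spurious mismatches.
A minimal sketch:

    import torch

    a = torch.tensor([float("nan"), 1.0])
    b = torch.tensor([float("nan"), 1.0])

    torch.allclose(a, b)                  # False: NaN compares unequal to NaN by default
    torch.allclose(a, b, equal_nan=True)  # True: NaNs in matching positions count as equal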