mirror of https://github.com/llvm/torch-mlir
Revert "[LINALG] Decompose `aten.batch_norm` into `aten.native_batch_norm`"
This reverts commit 442ff4605c.
pull/599/head
parent ba29d4f250
commit 056cd2078d
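The change being reverted had lowered `aten.batch_norm` by first decomposing it into `aten.native_batch_norm`; this revert restores the direct linalg lowering (re-added below) and deletes that decomposition. The decomposition leaned on the fact that, in inference mode, the first result of `aten.native_batch_norm` is exactly the `aten.batch_norm` output (see the removed DecomposeAtenBatchNormOp further down). A rough PyTorch sketch of that relationship, using arbitrary illustration shapes rather than anything from the patch:

import torch

x = torch.rand(2, 5, 3, 2)                   # (N, C, H, W)
weight, bias = torch.rand(5), torch.rand(5)  # per-channel scale and shift
running_mean, running_var = torch.rand(5), torch.rand(5)

# The op the e2e tests below exercise directly.
out = torch.ops.aten.batch_norm(
    x, weight, bias, running_mean, running_var, training=False,
    momentum=0.1, eps=0.00001, cudnn_enabled=False)

# native_batch_norm returns (output, save_mean, save_invstd); in inference
# mode its first result matches aten.batch_norm, which is what the removed
# decomposition pattern relied on.
native_out, _, _ = torch.ops.aten.native_batch_norm(
    x, weight, bias, running_mean, running_var, training=False,
    momentum=0.1, eps=0.00001)

assert torch.allclose(out, native_out)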
@@ -89,32 +89,6 @@ def BatchNorm3DModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
-
-class BatchNormModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1, -1], torch.float32, True),
-        ([-1], torch.float32, True),
-        ([-1], torch.float32, True),
-        ([-1], torch.float32, True),
-        ([-1], torch.float32, True),
-    ])
-    def forward(self, x, weight, bias, running_mean, running_var):
-        return torch.ops.aten.batch_norm(
-            x, weight, bias, running_mean, running_var, training=False,
-            momentum=0.1, eps=0.00001, cudnn_enabled=False)
-
-
-@register_test_case(module_factory=lambda: BatchNormModule())
-def BatchNormModule_basic(module, tu: TestUtils):
-    module.forward(
-        tu.rand(2, 5, 3, 2), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
-
-# ==============================================================================
 
 class NativeBatchNorm1DModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -193,7 +167,7 @@ def NativeBatchNorm3DModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-class NativeBatchNormWeightNoneModule(torch.nn.Module):
+class NativeBatchNormNoneWeightModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -207,39 +181,16 @@ class NativeBatchNormWeightNoneModule(torch.nn.Module):
     ])
     def forward(self, x, bias, running_mean, running_var):
         return torch.ops.aten.native_batch_norm(
-            x, weight=None, bias=bias, running_mean=running_mean,
-            running_var=running_var, training=False, momentum=0.1, eps=0.00001)
+            x, None, bias, running_mean, running_var, training=False,
+            momentum=0.1, eps=0.00001)
 
 
-@register_test_case(module_factory=lambda: NativeBatchNormWeightNoneModule())
-def NativeBatchNormWeightNoneModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule())
+def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5))
 
 # ==============================================================================
-
-class NativeBatchNormWeightNoneBiasNoneModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.float32, True),
-        ([-1], torch.float32, True),
-        ([-1], torch.float32, True),
-    ])
-    def forward(self, x, running_mean, running_var):
-        return torch.ops.aten.native_batch_norm(
-            x, weight=None, bias=None, running_mean=running_mean,
-            running_var=running_var, training=False, momentum=0.1, eps=0.00001)
-
-
-@register_test_case(module_factory=lambda: NativeBatchNormWeightNoneBiasNoneModule())
-def NativeBatchNormWeightNoneBiasNoneModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 5), tu.rand(5), tu.rand(5))
-
-# ==============================================================================
 
 class NativeLayerNormModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -86,11 +86,15 @@ TOSA_PASS_SET = {
     "ElementwiseReciprocalModule_basic",
     "TypePromotionAlphaWiderModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
+    "BatchNorm1DModule_basic",
+    "BatchNorm2DModule_basic",
+    "BatchNorm3DModule_basic",
     "FlattenStaticModule_basic",
     "FlattenRank0Module_basic",
     "ElementwiseFlattenBroadcastModule_basic",
     "SquareModule_basic",
     "MaxPool2dStaticModule_basic",
+    "ResNet18StaticModule_basic",
     "NativeLayerNormModule4D_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
     "PermuteModule_basic",
@@ -9,7 +9,6 @@
 #ifndef TORCHMLIR_DIALECT_TORCH_UTILS_H
 #define TORCHMLIR_DIALECT_TORCH_UTILS_H
 
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
@@ -23,7 +22,6 @@ int64_t toPositiveDim(int64_t dim, int64_t inputRank);
 bool isValidDim(int64_t dim, int64_t inputRank);
 bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
 torch_upstream::ScalarType getScalarTypeForType(Type type);
-LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v);
 
 } // namespace Torch
 } // namespace torch
@@ -70,6 +70,15 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
   return success();
 }
 
+static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
+                                  Value v) {
+  Type type = v.getType();
+  if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
+      type.isa<mlir::NoneType>())
+    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
+  return success();
+}
+
 // Generate IR: dim = dim >= 0 ? dim : dim + inputRank
 static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
                                   Value inputRank) {
@@ -599,6 +608,111 @@ static void createLinalgPayloadCalculationForGatherOps(
   b.create<linalg::YieldOp>(loc, extract);
 }
 
+namespace {
+class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
+    Value input = adaptor.input();
+    Value weight = adaptor.weight();
+    Value bias = adaptor.bias();
+    Value runningMean = adaptor.running_mean();
+    Value runningVar = adaptor.running_var();
+    Value training = adaptor.training();
+    Value eps = adaptor.eps();
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    // TODO: Handle the None cases for the optional parameters:
+    // weight, bias.
+    if (failed(checkNotNone(rewriter, op, weight)) ||
+        failed(checkNotNone(rewriter, op, bias)) ||
+        failed(checkNotNone(rewriter, op, runningMean)) ||
+        failed(checkNotNone(rewriter, op, runningVar)))
+      return failure();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto weightType = weight.getType().cast<RankedTensorType>();
+    auto biasType = bias.getType().cast<RankedTensorType>();
+    auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
+    auto runningVarType = runningVar.getType().cast<RankedTensorType>();
+
+    auto inputRank = inputType.getRank();
+    if (inputRank <= 2)
+      return rewriter.notifyMatchFailure(
+          op, "input should have rank larger than 2");
+
+    if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
+        runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expect weight, bias, running_mean and running_var to be rank 1");
+    }
+
+    // TODO: Add support for training.
+    auto constFalse = rewriter.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(IntegerType::get(context, 1), 0));
+    auto trainingFalse = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, training, constFalse);
+    rewriter.create<cf::AssertOp>(
+        loc, trainingFalse,
+        rewriter.getStringAttr("training is not supported for now"));
+
+    // num_features – C from an expected input of size (N,C,D,H,W ...)
+    Value numFeatures = rewriter.create<tensor::DimOp>(loc, input, 1);
+    auto contractingDim0EqualsNumFeatures = [&](Value v) {
+      auto dim0 = rewriter.create<tensor::DimOp>(loc, v, 0);
+      auto dim0Equal = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, numFeatures, dim0);
+      rewriter.create<cf::AssertOp>(
+          loc, dim0Equal,
+          rewriter.getStringAttr(
+              "expect the size of dim 0 equal to the number of features"));
+    };
+    contractingDim0EqualsNumFeatures(weight);
+    contractingDim0EqualsNumFeatures(bias);
+    contractingDim0EqualsNumFeatures(runningMean);
+    contractingDim0EqualsNumFeatures(runningVar);
+
+    auto indexingMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context);
+    SmallVector<AffineMap> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        indexingMap,                                // weight
+        indexingMap,                                // bias
+        indexingMap,                                // runningMean
+        indexingMap,                                // runningVar
+        rewriter.getMultiDimIdentityMap(inputRank), // output
+    };
+    SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
+    Value batchNorm =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, input.getType(),
+                ValueRange{input, weight, bias, runningMean, runningVar}, input,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], weight = args[1], bias = args[2],
+                        mean = args[3], var = args[4];
+                  Value result = createLinalgPayloadCalculationForNormOps(
+                      b, loc, var.getType(), input, mean, var, eps, weight,
+                      bias);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, batchNorm);
+    return success();
+  }
+};
+} // namespace
+
 // For layernorm, the mean and standard-deviation are calculated separately over
 // the last certain number dimensions which have to be of the shape specified by
 // normalized_shape.
@@ -4532,6 +4646,8 @@ public:
     patterns.add<ConvertAtenBmmOp>(typeConverter, context);
     target.addIllegalOp<AtenLinearOp>();
     patterns.add<ConvertAtenLinearOp>(typeConverter, context);
+    target.addIllegalOp<AtenBatchNormOp>();
+    patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
     target.addIllegalOp<
         AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
         AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
@@ -900,9 +900,7 @@ class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
     return success();
   }
 };
-} // namespace
 
-namespace {
 class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
   using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenLayerNormOp op,
@@ -931,40 +929,6 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
 };
 } // namespace
 
-namespace {
-class DecomposeAtenBatchNormOp : public OpRewritePattern<AtenBatchNormOp> {
-  using OpRewritePattern<AtenBatchNormOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenBatchNormOp op,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Add support for `training` mode.
-    bool training = false;
-    if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
-        training)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: training mode is not supported");
-
-    // The `mean` and `invstd` outputs shape should be {0} in the inference
-    // mode.
-    BaseTensorType tensorType = op.getType().cast<BaseTensorType>();
-    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: non-floating point type input");
-    Type emptyType =
-        tensorType.getWithSizesAndDtype({0}, tensorType.getDtype());
-
-    // The first output tensor of the `AtenNativeBatchNormOp` is essentially
-    // `AtenBatchNormOp` result.
-    auto nativeBatchNorm = rewriter.create<AtenNativeBatchNormOp>(
-        op.getLoc(), op.getType(), /*meanType=*/emptyType,
-        /*invStdType=*/emptyType, op.input(), op.weight(), op.bias(),
-        op.running_mean(), op.running_var(), op.training(), op.momentum(),
-        op.eps());
-    rewriter.replaceOp(op, nativeBatchNorm.getResult(0));
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 // Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
 class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
@@ -1063,14 +1027,6 @@ class DecomposeAtenNativeBatchNormOp
     Value runningVar = op.running_var();
     Value eps = op.eps();
 
-    // TODO: Add support for optional type parameters.
-    if (weight.getType().isa<OptionalType>() ||
-        bias.getType().isa<OptionalType>() ||
-        runningMean.getType().isa<OptionalType>() ||
-        runningVar.getType().isa<OptionalType>())
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: optional type arg is not supported");
-
     // TODO: Add support for `training` mode.
     bool training = false;
     if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
@@ -1097,24 +1053,13 @@ class DecomposeAtenNativeBatchNormOp
       return rewriter.notifyMatchFailure(
           op, "expected running_mean and running_var to be rank 1");
 
-    // The shape of `runningMean` and `runningVar` must be (numFeatures). Here,
-    // 'numFeatures' is C from an expected 'input' of size (N,C,D?,H?,W?).
     Value zero =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
     Value one =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
     Value numFeatures = rewriter.create<AtenSizeIntOp>(loc, input, /*dim=*/one);
-    auto dim0EqualsNumFeatures = [&](Value v) {
-      Value dim0 = rewriter.create<AtenSizeIntOp>(loc, v, /*dim=*/zero);
-      Value eqCmp = rewriter.create<AtenEqIntOp>(loc, BoolType::get(context),
-                                                 dim0, numFeatures);
-      rewriter.create<RuntimeAssertOp>(
-          loc, eqCmp,
-          rewriter.getStringAttr("size of the 0th dimension must be equal to "
-                                 "the number of features"));
-    };
-    dim0EqualsNumFeatures(runningMean);
-    dim0EqualsNumFeatures(runningVar);
+    // TODO: Add Runtime Asserts to check the shape of weight, bias,
+    // running_mean and running_var to be (numFeatures).
 
     // The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?)
     // to make it broadcast-compatible with (N, C, D?, H?, W?).
@@ -1152,22 +1097,18 @@ class DecomposeAtenNativeBatchNormOp
     // 3. output = normalizedInput * weight + bias
     Value batchNormOutput = normalizedInput;
     if (!weight.getType().isa<Torch::NoneType>()) {
-      // The shape of the `weight` tensor must be (numFeatures).
+      // Rank of `weight` must be exactly 1.
       if (getTensorRank(weight) != 1)
        return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
-      dim0EqualsNumFeatures(weight);
-
       weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
                                            runningStatsSizeList);
       batchNormOutput = rewriter.create<AtenMulTensorOp>(
           loc, batchNormOutput.getType(), batchNormOutput, weight);
     }
     if (!bias.getType().isa<Torch::NoneType>()) {
-      // The shape of the `bias` tensor must be (numFeatures).
+      // Rank of `bias` must be exactly 1.
       if (getTensorRank(bias) != 1)
         return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
-      dim0EqualsNumFeatures(bias);
-
       bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
                                          runningStatsSizeList);
       batchNormOutput = rewriter.create<AtenAddTensorOp>(
@@ -1278,8 +1219,6 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeAtenLayerNormOp>(context);
     target.addIllegalOp<AtenNativeBatchNormOp>();
     patterns.add<DecomposeAtenNativeBatchNormOp>(context);
-    target.addIllegalOp<AtenBatchNormOp>();
-    patterns.add<DecomposeAtenBatchNormOp>(context);
     patterns.add<DecomposeAtenArangeOp>(context);
     target.addIllegalOp<AtenArangeOp>();
     patterns.add<DecomposeAtenArangeStartOp>(context);
@@ -1867,10 +1867,9 @@ ChangeResult TypeAnalyzer::visitAtenNativeBatchNormOp(
   meanKnowledge.dtype = input.dtype;
   invStdKnowledge.dtype = input.dtype;
 
-  // Rank of the input tensor must be greater than or equal to 2. The shape
-  // of the input tensor as well as the batch norm output tensor should be
-  // (N, C, D?, H?, W?). In inference mode, the mean and inv-std outputs should
-  // be empty tensors, whereas they should be of shape (C) in the training mode.
+  // Rank of the input tensor must be greater than or equal to 2. The size of
+  // the input tensor as well as the output tensor should be (N, C, D?, H?, W?).
+  // The running_mean, running_var, weight, and bias should be of size (C).
   bool training = false;
   if (matchPattern(op.training(), m_TorchConstantBool(&training)) &&
       input.hasSizes && input.sizes.size() >= 2) {
@@ -46,14 +46,6 @@ ScalarType getScalarTypeForType(Type type) {
   llvm::report_fatal_error("unhandled type for getScalarTypeForType");
 }
 
-LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
-  Type type = v.getType();
-  if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
-      type.isa<mlir::NoneType>())
-    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
-  return success();
-}
-
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
@@ -448,49 +448,3 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
   %0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
-
-// -----
-// CHECK-LABEL: func @torch.aten.batch_norm(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?,?],f32>,
-// CHECK-SAME: %[[WEIGHT:.*]]: !torch.vtensor<[?],f32>, %[[BIAS:.*]]: !torch.vtensor<[?],f32>, %[[RMEAN:.*]]: !torch.vtensor<[?],f32>, %[[RVAR:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
-// CHECK: %[[EPS:.*]] = torch.constant.float 1.000000e-05
-// CHECK: %[[MOM:.*]] = torch.constant.float 1.000000e-01
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INPUT_DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[INT1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[RMEAN_DIM0:.*]] = torch.aten.size.int %[[RMEAN]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
-// CHECK: %[[PRED_MEAN:.*]] = torch.aten.eq.int %[[RMEAN_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
-// CHECK: torch.runtime.assert %[[PRED_MEAN]], "size of the 0th dimension must be equal to the number of features"
-// CHECK: %[[RVAR_DIM0:.*]] = torch.aten.size.int %[[RVAR]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
-// CHECK: %[[PRED_VAR:.*]] = torch.aten.eq.int %[[RVAR_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
-// CHECK: torch.runtime.assert %[[PRED_VAR]], "size of the 0th dimension must be equal to the number of features"
-// CHECK: %[[SIZE_LIST:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INPUT_DIM1]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
-// CHECK: %[[RMEAN_VIEW:.*]] = torch.aten.view %[[RMEAN]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,?,1,1],f32>
-// CHECK: %[[RVAR_VIEW:.*]] = torch.aten.view %[[RVAR]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,?,1,1],f32>
-// CHECK: %[[X_SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[RMEAN_VIEW]], %[[INT1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
-// CHECK: %[[VAR_EPS:.*]] = torch.aten.add.Scalar %[[RVAR_VIEW]], %[[EPS]], %[[INT1]] : !torch.vtensor<[1,?,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,?,1,1],f32>
-// CHECK: %[[SQRT_VAR_EPS:.*]] = torch.aten.rsqrt %[[VAR_EPS]] : !torch.vtensor<[1,?,1,1],f32> -> !torch.vtensor<[1,?,1,1],f32>
-// CHECK: %[[NORM_INPUT:.*]] = torch.aten.mul.Tensor %[[X_SUB_MEAN]], %[[SQRT_VAR_EPS]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32> -> !torch.vtensor<[?,?,?,?],f32>
-// CHECK: %[[WEIGHT_DIM0:.*]] = torch.aten.size.int %[[WEIGHT]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
-// CHECK: %[[PRED_WEIGHT:.*]] = torch.aten.eq.int %[[WEIGHT_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
-// CHECK: torch.runtime.assert %[[PRED_WEIGHT]], "size of the 0th dimension must be equal to the number of features"
-// CHECK: %[[WEIGHT_VIEW:.*]] = torch.aten.view %[[WEIGHT]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,?,1,1],f32>
-// CHECK: %[[SCALED_INPUT:.*]] = torch.aten.mul.Tensor %[[NORM_INPUT]], %[[WEIGHT_VIEW]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32> -> !torch.vtensor<[?,?,?,?],f32>
-// CHECK: %[[BIAS_DIM0:.*]] = torch.aten.size.int %[[BIAS]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
-// CHECK: %[[PRED_BIAS:.*]] = torch.aten.eq.int %[[BIAS_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
-// CHECK: torch.runtime.assert %[[PRED_BIAS]], "size of the 0th dimension must be equal to the number of features"
-// CHECK: %[[BIAS_VIEW:.*]] = torch.aten.view %[[BIAS]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,?,1,1],f32>
-// CHECK: %[[OUTPUT:.*]] = torch.aten.add.Tensor %[[SCALED_INPUT]], %[[BIAS_VIEW]], %[[INT1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
-// CHECK: %[[ZERO_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<!torch.int>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[MEAN_OUT:.*]] = torch.aten.empty.memory_format %[[ZERO_LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[0],f32>
-// CHECK: %[[INV_STD_OUT:.*]] = torch.aten.empty.memory_format %[[ZERO_LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[0],f32>
-// CHECK: return %[[OUTPUT]] : !torch.vtensor<[?,?,?,?],f32>
-func @torch.aten.batch_norm(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>, %arg2: !torch.vtensor<[?],f32>, %arg3: !torch.vtensor<[?],f32>, %arg4: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
-  %float1.000000e-05 = torch.constant.float 1.000000e-05
-  %float1.000000e-01 = torch.constant.float 1.000000e-01
-  %false = torch.constant.bool false
-  %0 = torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %false, %float1.000000e-01, %float1.000000e-05, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?,?],f32>
-}
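The FileCheck block removed just above spells out the inference-mode arithmetic that both the restored ConvertAtenBatchNormOp lowering and the deleted decomposition compute: subtract the running mean, add eps to the running variance, take rsqrt, multiply, then scale by weight and add bias, with the rank-1 statistics broadcast along dim 1 of an (N, C, ...) input. A hedged Python sketch of that per-channel formula, checked against the public op; the helper name and shapes are illustrative, not part of the patch:

import torch

def batch_norm_inference_reference(x, weight, bias, running_mean, running_var,
                                   eps=1e-5):
    # (x - mean) / sqrt(var + eps) * weight + bias, with the (C,)-shaped
    # statistics reshaped to (1, C, 1, ...) so they broadcast over dim 1.
    shape = [1, -1] + [1] * (x.dim() - 2)
    normalized = (x - running_mean.reshape(shape)) * torch.rsqrt(
        running_var.reshape(shape) + eps)
    return normalized * weight.reshape(shape) + bias.reshape(shape)

x = torch.rand(2, 5, 3, 2)
w, b, m, v = torch.rand(5), torch.rand(5), torch.rand(5), torch.rand(5)
ref = batch_norm_inference_reference(x, w, b, m, v)
out = torch.ops.aten.batch_norm(x, w, b, m, v, training=False, momentum=0.1,
                                eps=1e-5, cudnn_enabled=False)
assert torch.allclose(out, ref, atol=1e-5)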