torch,linalg: add support for translating aten.linalg.vector_norm (#839)

This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function.  It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.

There exist several opportunities to make this lowering optimal and
robust.  For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf.  For L1 norms, we don't need to raise
each element to the power 1.0.  Similarly, L2 norms could benefit from
strength reduction.  Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
pull/869/head
Ashay Rane 2022-05-19 15:48:15 -07:00 committed by GitHub
parent 3fb54cba4c
commit f18b2be911
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 522 additions and 192 deletions

View File

@ -3625,6 +3625,33 @@ def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
}]; }];
} }
def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$ord,
AnyTorchOptionalListOfTorchIntType:$dim,
Torch_BoolType:$keepdim,
AnyTorchOptionalIntType:$dtype
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLinalgVectorNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenLinalgVectorNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
@ -67,7 +68,8 @@ public:
bool keepDim = false; bool keepDim = false;
if (!matchPattern(maxDimOp.keepdim(), m_TorchConstantBool(&keepDim))) if (!matchPattern(maxDimOp.keepdim(), m_TorchConstantBool(&keepDim)))
return failure(); return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim requires boolean value for keepdim");
int64_t dim; int64_t dim;
if (!matchPattern(maxDimOp.dim(), m_TorchConstantInt(&dim))) if (!matchPattern(maxDimOp.dim(), m_TorchConstantInt(&dim)))
@ -173,9 +175,8 @@ public:
}; };
} // namespace } // namespace
static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc, static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
Operation *op, Operation *op, Type elementType) {
Type elementType) {
if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) if (isa<AtenSumOp, AtenSumDimIntListOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType)); return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
@ -195,14 +196,18 @@ static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
elementType.getIntOrFloatBitWidth()))); elementType.getIntOrFloatBitWidth())));
} }
op->emitError("unimplemented lowering in " if (isa<AtenLinalgVectorNormOp>(op))
"createLinalgNeutralElementForReduceOp"); return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
op->emitError("unimplemented lowering in createInitElementForReduceOp");
return nullptr; return nullptr;
} }
static Value createLinalgPayloadCalculationForReduceOp( static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op, ValueRange payloadArgs,
ArrayRef<Value> operands, Type resultElementType) { Operation *op,
ArrayRef<Value> operands,
Type resultElementType) {
if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) { if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) {
Value self = Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
@ -228,14 +233,180 @@ static Value createLinalgPayloadCalculationForReduceOp(
if (intType.isSigned()) if (intType.isSigned())
return b.create<arith::MaxSIOp>(loc, self, result); return b.create<arith::MaxSIOp>(loc, self, result);
} }
} else if (isa<AtenLinalgVectorNormOp>(op)) {
// This creates payload for only the first of the two linalg.generic ops.
// TODO: Short-circuit operations if `ord` is zero or one.
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
auto abs = b.create<math::AbsOp>(loc, self);
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
Value ord = convertScalarToDtype(b, loc, adaptor.ord(), resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
return b.create<arith::AddFOp>(loc, pow, result);
} }
op->emitError("unimplemented lowering in " op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
"createLinalgPayloadCalculationForReduceOp");
return nullptr; return nullptr;
} }
namespace { namespace {
class ConvertReductionOp : public ConversionPattern { class ConvertReductionOp : public ConversionPattern {
private:
/// Given a reduction operation that has the `keepdim` attribute and the
/// (optional) `dim` attribute, return the source tensor operand and the
/// literal values of the attributes or failure otherwise.
template <typename T>
FailureOr<torch_to_linalg::ReductionOpInfo>
computeReductionOpInfoForDimVariantOp(
T op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
typename T::Adaptor adaptor(operands);
opInfo.tensorOperand = adaptor.self();
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&opInfo.keepDim)))
return rewriter.notifyMatchFailure(op,
"`keepdim` must be a constant bool");
SmallVector<int64_t> dimList;
if (matchPattern(op.dim(), m_TorchConstantIntList(dimList))) {
// Fix negative dimensions, if any, before adding to the list.
for (int64_t dim : dimList) {
dim = toPositiveDim(dim, inputType.getRank());
// Drop invalid dimensions
if (isValidDim(dim, inputType.getRank()))
opInfo.dimSet.insert(dim);
}
} else if (op.dim().getType().template isa<Torch::NoneType>()) {
// If no dimensions were specified, reduce along all dimensions
for (int64_t i = 0; i < inputType.getRank(); i++)
opInfo.dimSet.insert(i);
} else {
return rewriter.notifyMatchFailure(
op, "`dim` argument must be a constant int list or None");
}
return opInfo;
}
/// Given a reduction operation, return the source tensor operand and the
/// literal values of the `keepdim` and `dim` attributes, if any, or failure
/// otherwise.
FailureOr<torch_to_linalg::ReductionOpInfo>
computeReductionOpInfo(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
if (isa<AtenMaxOp, AtenSumOp>(op)) {
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
// `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
// input tensor.
for (int64_t i = 0; i < inputType.getRank(); i++)
opInfo.dimSet.insert(i);
return opInfo;
}
if (auto sumOp = dyn_cast<AtenSumDimIntListOp>(op))
return computeReductionOpInfoForDimVariantOp(sumOp, operands, rewriter);
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op))
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
}
/// Generate a linalg.generic operation for pointwise exponentiation of each
/// element.
Value createElementwiseExp(Location loc, Type elemType, Value exponent,
Value inputTensor,
const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
bool err = false;
auto powBodyBuilder = [&](OpBuilder &builder, Location loc,
ValueRange payloadArgs) {
Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], elemType);
auto result = builder.create<math::PowFOp>(loc, elem, exponent);
if (result)
builder.create<linalg::YieldOp>(loc, Value{result});
err = !result;
};
Value powOp = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, {inputTensor}, elemType, powBodyBuilder);
return err ? Value{} : powOp;
}
FailureOr<Value> createSecondReductionForVectorNormOp(
Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp,
Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
// Cast `ord` to float so that we can readily pass it math.powf.
Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);
// TODO: Add support for ord = {0, +inf, -inf}.
auto epsilon = 1e-5;
auto ordLiteral = 0.0;
if (matchPattern(ordValue, m_TorchConstantFloat(&ordLiteral)) &&
fabs(ordLiteral) < epsilon)
return rewriter.notifyMatchFailure(op, "unimplemented: L0 norm");
if (std::isinf(ordLiteral))
return rewriter.notifyMatchFailure(op, "unimplemented: ord = +/- inf");
// Raise each summed value to the inverse of the order of the norm.
Attribute oneAttr = rewriter.getFloatAttr(elemType, 1.0);
auto oneValue = rewriter.create<arith::ConstantOp>(loc, oneAttr);
auto inverseOrdValue =
rewriter.create<arith::DivFOp>(loc, oneValue, ordValue);
// Use the results of the first reduction operation from above to generate
// a second reduction operation.
Value reduceOp = createElementwiseExp(loc, elemType, inverseOrdValue,
firstReduction, opInfo, rewriter);
if (!reduceOp)
return rewriter.notifyMatchFailure(
op, "failed to create linalg.generic operation for element-wise "
"exponentiation");
return reduceOp;
}
/// Generate a linalg.generic operation for a reduction.
Value createReductionOp(Location loc, Type elemType, Operation *op,
ArrayRef<Value> operands,
const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
bool err = false;
auto reductionBodyBuilder = [&](OpBuilder &builder, Location loc,
ValueRange payloadArgs) {
Value result = createLinalgPayloadForReduceOp(builder, loc, payloadArgs,
op, operands, elemType);
if (result)
builder.create<linalg::YieldOp>(loc, result);
err = !result;
};
Value initElem = createInitElementForReduceOp(rewriter, loc, op, elemType);
Value reduceOp = torch_to_linalg::createReductionLinalgGeneric(
rewriter, loc, opInfo, initElem, reductionBodyBuilder);
return err ? Value{} : reduceOp;
}
/// Depending on the operation, check validity of the result's element type.
LogicalResult
validateReductionElementType(Operation *op, Type elemType,
ConversionPatternRewriter &rewriter) const {
if (isa<AtenLinalgVectorNormOp>(op) && !elemType.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "only float types are valid for vector norm ops");
// No checks for all other reduction operations
return success();
}
public: public:
ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context) ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
@ -244,67 +415,42 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return rewriter.notifyMatchFailure(
op, "invalid operand or result types to use with linalg on tensors");
// Every reduce operation must set a value for the `dimSet`, FailureOr<torch_to_linalg::ReductionOpInfo> opInfo =
// `tensorOperand`, and `keepDim` in accordance with their specification. computeReductionOpInfo(op, operands, rewriter);
DenseSet<int64_t> dimSet; if (failed(opInfo))
Value tensorOperand; return opInfo;
bool keepDim = false;
if (isa<AtenSumOp>(op) || isa<AtenMaxOp>(op)) {
tensorOperand = operands[0];
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
// `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
// input tensor.
for (int64_t i = 0; i < inputType.getRank(); i++)
dimSet.insert(i);
} else if (auto sumDimIntListOp = dyn_cast<AtenSumDimIntListOp>(op)) {
tensorOperand = operands[0];
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
if (!matchPattern(sumDimIntListOp.keepdim(),
m_TorchConstantBool(&keepDim)))
return failure();
SmallVector<int64_t> dimList;
if (!matchPattern(sumDimIntListOp.dim(), m_TorchConstantIntList(dimList)))
return failure();
for (auto dim : dimList) {
// Torch allows for negative values in dimSet to go in reverse
// order in the dimensions of the input tensor.
dim = dim >= 0 ? dim : dim + inputType.getRank();
// Drop invalid dimensions
if (dim < inputType.getRank())
dimSet.insert(dim);
}
} else {
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
}
Location loc = op->getLoc(); Location loc = op->getLoc();
auto resultType = getTypeConverter() auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType()) ->convertType(op->getResult(0).getType())
.cast<RankedTensorType>(); .cast<RankedTensorType>();
Value initElem = createLinalgNeutralElementForReduceOp( Type elemType = resultType.getElementType();
rewriter, loc, op, resultType.getElementType()); LogicalResult elemTypeCheck =
validateReductionElementType(op, elemType, rewriter);
if (failed(elemTypeCheck))
return elemTypeCheck;
bool hadErrorCreatingPayload = false; Value reduceOp =
Value generic = torch_to_linalg::createReductionLinalgGeneric( createReductionOp(loc, elemType, op, operands, *opInfo, rewriter);
rewriter, loc, tensorOperand, dimSet, keepDim, initElem, if (!reduceOp)
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) { return rewriter.notifyMatchFailure(
Value result = createLinalgPayloadCalculationForReduceOp( op, "failed to create linalg.generic operation for reduction");
b, loc, payloadArgs, op, operands, resultType.getElementType());
if (!result) { // If this is aten.linalg_vector_norm op, then we need to generate another
hadErrorCreatingPayload = true; // linalg.generic op that references the first linalg.generic op.
return; if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
FailureOr<Value> secondReduceOp = createSecondReductionForVectorNormOp(
loc, elemType, normOp, adaptor.ord(), reduceOp, *opInfo, rewriter);
if (failed(secondReduceOp))
return secondReduceOp;
reduceOp = *secondReduceOp;
} }
b.create<linalg::YieldOp>(loc, result);
});
if (hadErrorCreatingPayload) rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, reduceOp);
return failure();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
return success(); return success();
} }
}; };
@ -319,5 +465,6 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenSumOp>(); target.addIllegalOp<AtenSumOp>();
target.addIllegalOp<AtenSumDimIntListOp>(); target.addIllegalOp<AtenSumDimIntListOp>();
target.addIllegalOp<AtenMaxOp>(); target.addIllegalOp<AtenMaxOp>();
target.addIllegalOp<AtenLinalgVectorNormOp>();
patterns.add<ConvertReductionOp>(typeConverter, context); patterns.add<ConvertReductionOp>(typeConverter, context);
} }

View File

@ -35,113 +35,6 @@ template <typename elementType> static bool hasElementType(Value tensor) {
return tensorElementType.isa<elementType>(); return tensorElementType.isa<elementType>();
} }
static Value createElementwiseLinalgGeneric(
OpBuilder &b, Location loc, ValueRange tensorOperands,
Type resultElementType,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
// The overall error handling strategy here is best viewed by thinking about
// what happens for a single result dimension. This loop not structured that
// way because it is hard to create the affine maps for each operand unless
// we structure the loop to iterate over tensor operands as the outer loop
// instead of inner loop. This pseudocode gives better intuition:
// ```
// for each result dimension:
// for each tensor operand:
// if it doesn't even have high enough rank relative to the result:
// continue
// if it is a static size-1 along this result dimension:
// continue
// if this is the first tensor operand that didn't continue above:
// take its dimension size as the size of the non-broadcasted
// traversal along this dimension (this may include a dynamic size-1,
// **non-broadcasted** traversal!)
// emit error check "if the size does not match the non-broadcasted
// traversal size along this dimension, error"
// ```
SmallVector<int64_t> operandRanks;
operandRanks.resize(tensorOperands.size());
llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
return tensor.getType().dyn_cast<RankedTensorType>().getRank();
});
auto resultRankIt =
std::max_element(operandRanks.begin(), operandRanks.end());
assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
int64_t resultRank = *resultRankIt;
// Initialize the resultShape to all 1's, as a fallback in case
// all sizes along that result dimension are statically 1.
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape(resultRank, c1);
SmallVector<AffineMap> indexingMaps;
for (Value tensorOperand : tensorOperands) {
SmallVector<AffineExpr> exprs;
auto type = tensorOperand.getType().cast<RankedTensorType>();
for (auto size : llvm::enumerate(type.getShape())) {
// If the size is statically known to be 1, we don't want any
// error guards to be spuriously emitted, since we are specifically
// allowing size-1 broadcasts in this case, as they correspond to a
// constant-0 indexing map.
if (size.value() == 1) {
exprs.push_back(b.getAffineConstantExpr(0));
continue;
}
// The rank of this operand might be smaller than the overall rank of
// the broadcast. Add an offset to correlate it to the correct
// dimension of the result.
auto resultDim = size.index() + (resultRank - type.getRank());
// The generated linalg op will now be iterating along the full size
// of this dimension. Record that fact.
exprs.push_back(b.getAffineDimExpr(resultDim));
// Now, we need to ensure that such iteration is not going to trigger
// undefined behavior, by doing appropriate checks against the current
// dimension size.
auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());
// If the result size of this dimension has so far only hit the
// statically-known-to-be-1 case above (i.e., we have not yet assigned a
// new Value to `resultShape[resultDim]`), then we have no other dynamic
// values to check against, and merely need to record the current
// dimension size.
if (resultShape[resultDim] == c1) {
resultShape[resultDim] = currentDimSize;
continue;
}
// We prohibit the size-1 dynamic broadcasting scenario, so just check
// for exact equality with the running result size.
// This is the check which protects against the undefined behavior of
// the generated linalg op in the case of iterating two operands with
// dimensions sizes that are expected to match.
auto equalToRunning =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
resultShape[resultDim], currentDimSize);
b.create<cf::AssertOp>(loc, equalToRunning,
"mismatched size for broadcast");
}
indexingMaps.push_back(AffineMap::get(
/*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
}
SmallVector<StringRef> iteratorTypes(resultRank,
getParallelIteratorTypeName());
// Add the indexing map for the outs init tensor.
indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));
Value initTensor = b.create<linalg::InitTensorOp>(
loc, getAsOpFoldResult(resultShape), resultElementType);
return b
.create<linalg::GenericOp>(loc,
/*resultTensorTypes=*/initTensor.getType(),
/*inputs=*/tensorOperands,
/*outputs=*/initTensor, indexingMaps,
iteratorTypes, bodyBuild)
.getResult(0);
}
template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred, template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
arith::CmpIPredicate ispred> arith::CmpIPredicate ispred>
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
@ -964,7 +857,7 @@ public:
->convertType(op->getResult(0).getType()) ->convertType(op->getResult(0).getType())
.cast<RankedTensorType>(); .cast<RankedTensorType>();
bool hadErrorCreatingPayload = false; bool hadErrorCreatingPayload = false;
Value generic = createElementwiseLinalgGeneric( Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, tensorOperands, resultType.getElementType(), rewriter, loc, tensorOperands, resultType.getElementType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) { [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = createLinalgPayloadCalculationForElementwiseOp( Value result = createLinalgPayloadCalculationForElementwiseOp(
@ -1031,7 +924,7 @@ public:
Value zeroVal = rewriter.create<arith::ConstantOp>( Value zeroVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType)); loc, rewriter.getZeroAttr(elementType));
Value finalRes = createElementwiseLinalgGeneric( Value finalRes = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, {target}, elementType, rewriter, loc, {target}, elementType,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
Value targetVal = args[0]; Value targetVal = args[0];
@ -1068,8 +961,9 @@ public:
/*inclusive=*/false); /*inclusive=*/false);
DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end()); DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());
auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet};
finalRes = torch_to_linalg::createReductionLinalgGeneric( finalRes = torch_to_linalg::createReductionLinalgGeneric(
rewriter, loc, finalRes, dimSet, /*keepDim=*/false, rewriter, loc, opInfo,
/*initElem=*/zeroVal, /*initElem=*/zeroVal,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
Value newVal = args[0]; Value newVal = args[0];

View File

@ -98,23 +98,22 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
} }
Value torch_to_linalg::createReductionLinalgGeneric( Value torch_to_linalg::createReductionLinalgGeneric(
OpBuilder &b, Location loc, Value tensorOperand, OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
const DenseSet<int64_t> &dimSet, bool keepDim, Value initElem,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
auto inputType = tensorOperand.getType().cast<RankedTensorType>(); auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
// Get the result shape by obtaining the size of each // Get the result shape by obtaining the size of each
// dimension in the input tensor that is not getting reduced. // dimension in the input tensor that is not getting reduced.
// If `keepDim` is true, the rank of the output tensor // If `opInfo.keepDim` is true, the rank of the output tensor
// is kept the same as the rank of the input tensor, and the // is kept the same as the rank of the input tensor, and the
// reduced dimensions are set to have size 1. // reduced dimensions are set to have size 1.
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1); auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape; SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) { for (int64_t i = 0; i < inputType.getRank(); i++) {
auto currentDimSize = b.create<tensor::DimOp>(loc, tensorOperand, i); auto currentDimSize = b.create<tensor::DimOp>(loc, opInfo.tensorOperand, i);
if (!dimSet.contains(i)) if (!opInfo.dimSet.contains(i))
resultShape.push_back(currentDimSize); resultShape.push_back(currentDimSize);
else if (keepDim) else if (opInfo.keepDim)
resultShape.push_back(c1); resultShape.push_back(c1);
} }
@ -127,11 +126,11 @@ Value torch_to_linalg::createReductionLinalgGeneric(
for (auto size : llvm::enumerate(inputType.getShape())) { for (auto size : llvm::enumerate(inputType.getShape())) {
exprs.push_back(b.getAffineDimExpr(size.index())); exprs.push_back(b.getAffineDimExpr(size.index()));
if (dimSet.contains(size.index())) { if (opInfo.dimSet.contains(size.index())) {
iteratorTypes.push_back(getReductionIteratorTypeName()); iteratorTypes.push_back(getReductionIteratorTypeName());
// If `keepDim`, create affine map to the first element // If `opInfo.keepDim`, create affine map to the first element
// in the current dimension. // in the current dimension.
if (keepDim) if (opInfo.keepDim)
resultExprs.push_back(b.getAffineConstantExpr(0)); resultExprs.push_back(b.getAffineConstantExpr(0));
} else { } else {
iteratorTypes.push_back(getParallelIteratorTypeName()); iteratorTypes.push_back(getParallelIteratorTypeName());
@ -146,7 +145,114 @@ Value torch_to_linalg::createReductionLinalgGeneric(
return b return b
.create<linalg::GenericOp>( .create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/accumulator.getType(), loc, /*resultTensorTypes=*/accumulator.getType(),
/*inputs=*/tensorOperand, /*inputs=*/opInfo.tensorOperand,
/*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild) /*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild)
.getResult(0); .getResult(0);
} }
Value torch_to_linalg::createElementwiseLinalgGeneric(
OpBuilder &b, Location loc, ValueRange tensorOperands,
Type resultElementType,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
// The overall error handling strategy here is best viewed by thinking about
// what happens for a single result dimension. This loop not structured that
// way because it is hard to create the affine maps for each operand unless
// we structure the loop to iterate over tensor operands as the outer loop
// instead of inner loop. This pseudocode gives better intuition:
// ```
// for each result dimension:
// for each tensor operand:
// if it doesn't even have high enough rank relative to the result:
// continue
// if it is a static size-1 along this result dimension:
// continue
// if this is the first tensor operand that didn't continue above:
// take its dimension size as the size of the non-broadcasted
// traversal along this dimension (this may include a dynamic size-1,
// **non-broadcasted** traversal!)
// emit error check "if the size does not match the non-broadcasted
// traversal size along this dimension, error"
// ```
SmallVector<int64_t> operandRanks;
operandRanks.resize(tensorOperands.size());
llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
return tensor.getType().dyn_cast<RankedTensorType>().getRank();
});
auto resultRankIt =
std::max_element(operandRanks.begin(), operandRanks.end());
assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
int64_t resultRank = *resultRankIt;
// Initialize the resultShape to all 1's, as a fallback in case
// all sizes along that result dimension are statically 1.
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape(resultRank, c1);
SmallVector<AffineMap> indexingMaps;
for (Value tensorOperand : tensorOperands) {
SmallVector<AffineExpr> exprs;
auto type = tensorOperand.getType().cast<RankedTensorType>();
for (auto size : llvm::enumerate(type.getShape())) {
// If the size is statically known to be 1, we don't want any
// error guards to be spuriously emitted, since we are specifically
// allowing size-1 broadcasts in this case, as they correspond to a
// constant-0 indexing map.
if (size.value() == 1) {
exprs.push_back(b.getAffineConstantExpr(0));
continue;
}
// The rank of this operand might be smaller than the overall rank of
// the broadcast. Add an offset to correlate it to the correct
// dimension of the result.
auto resultDim = size.index() + (resultRank - type.getRank());
// The generated linalg op will now be iterating along the full size
// of this dimension. Record that fact.
exprs.push_back(b.getAffineDimExpr(resultDim));
// Now, we need to ensure that such iteration is not going to trigger
// undefined behavior, by doing appropriate checks against the current
// dimension size.
auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());
// If the result size of this dimension has so far only hit the
// statically-known-to-be-1 case above (i.e., we have not yet assigned a
// new Value to `resultShape[resultDim]`), then we have no other dynamic
// values to check against, and merely need to record the current
// dimension size.
if (resultShape[resultDim] == c1) {
resultShape[resultDim] = currentDimSize;
continue;
}
// We prohibit the size-1 dynamic broadcasting scenario, so just check
// for exact equality with the running result size.
// This is the check which protects against the undefined behavior of
// the generated linalg op in the case of iterating two operands with
// dimensions sizes that are expected to match.
auto equalToRunning =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
resultShape[resultDim], currentDimSize);
b.create<cf::AssertOp>(loc, equalToRunning,
"mismatched size for broadcast");
}
indexingMaps.push_back(AffineMap::get(
/*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
}
SmallVector<StringRef> iteratorTypes(resultRank,
getParallelIteratorTypeName());
// Add the indexing map for the outs init tensor.
indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));
Value initTensor = b.create<linalg::InitTensorOp>(
loc, getAsOpFoldResult(resultShape), resultElementType);
return b
.create<linalg::GenericOp>(loc,
/*resultTensorTypes=*/initTensor.getType(),
/*inputs=*/tensorOperands,
/*outputs=*/initTensor, indexingMaps,
iteratorTypes, bodyBuild)
.getResult(0);
}

View File

@ -13,6 +13,12 @@ namespace mlir {
namespace torch { namespace torch {
namespace torch_to_linalg { namespace torch_to_linalg {
struct ReductionOpInfo {
bool keepDim;
Value tensorOperand;
DenseSet<int64_t> dimSet;
};
// Helper function to get the padding tensor given the padding int values. // Helper function to get the padding tensor given the padding int values.
Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input, Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
SmallVectorImpl<int64_t> &lowPaddingInts, SmallVectorImpl<int64_t> &lowPaddingInts,
@ -33,16 +39,21 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
Value kernelSizeInt, Value strideInt, Value kernelSizeInt, Value strideInt,
bool ceilMode = false); bool ceilMode = false);
// Create a reduction of `tensorOperand`, reducing along the dimensions // Create a reduction of `opInfo.tensorOperand`, reducing along the dimensions
// in `dimSet`. If `keepDim` is true, the output tensor is the same // in `opInfo.dimSet`. If `opInfo.keepDim` is true, the output tensor is the
// rank as the `tensorOperand` and reduced dimensions are set to size 1. // same rank as the `opInfo.tensorOperand` and reduced dimensions are set to
// `initElem` is the element used to initialize the output tensor // size 1. `initElem` is the element used to initialize the output tensor where
// where the reduction will be stored. // the reduction will be stored.
Value createReductionLinalgGeneric( Value createReductionLinalgGeneric(
OpBuilder &b, Location loc, Value tensorOperand, OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
const DenseSet<int64_t> &dimSet, bool keepDim, Value initElem,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild); function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
// Create a pointwise operation that uses values in `tensorOperands`, such that
// the element type of the resulting tensor is `resultElementType`.
Value createElementwiseLinalgGeneric(
OpBuilder &b, Location loc, ValueRange tensorOperands,
Type resultElementType,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
} // namespace torch_to_linalg } // namespace torch_to_linalg
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -988,6 +988,14 @@ ChangeResult TypeAnalyzer::visitOperation(
if (auto scalarImplicit = dyn_cast<AtenScalarImplicitOp>(op)) if (auto scalarImplicit = dyn_cast<AtenScalarImplicitOp>(op))
return visitAtenScalarImplicitOp(scalarImplicit, operands); return visitAtenScalarImplicitOp(scalarImplicit, operands);
if (auto vectorNorm = dyn_cast<AtenLinalgVectorNormOp>(op)) {
Type defaultDtype = operands[0]->getValue().dtype;
Type dtype = getDtypeOrDefault(vectorNorm.getContext(), vectorNorm.dtype(),
defaultDtype);
return visitReductionAlongDimIntListOp(
vectorNorm, vectorNorm.dim(), vectorNorm.keepdim(), dtype, operands);
}
// Otherwise, this is an unknown operation. Just mark all results as // Otherwise, this is an unknown operation. Just mark all results as
// having reached a pessimistic fixpoint. // having reached a pessimistic fixpoint.
return markAllPessimisticFixpoint(op->getResults()); return markAllPessimisticFixpoint(op->getResults());

View File

@ -3078,6 +3078,29 @@ module {
%none = torch.constant.none %none = torch.constant.none
return %none : !torch.none return %none : !torch.none
} }
func.func @"__torch_mlir_shape_fn.aten.linalg_vector_norm"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
%none = torch.constant.none
%true = torch.constant.bool true
%0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.list<int>) {
%2 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
%3 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
%4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
torch.prim.If.yield %4 : !torch.list<int>
} else {
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %2, %true, init() {
^bb0(%arg5: !torch.int):
%6 = torch.aten.append.t %3, %arg5 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%4 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
%5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
return %1 : !torch.list<int>
}
} }
)mlir"); )mlir");
#pragma clang diagnostic pop #pragma clang diagnostic pop

View File

@ -946,6 +946,11 @@ def hacky_get_unknown_dimension_size():
def atenbincount(self: List[int], weights: Optional[List[int]] = None, minlength: int = 0) -> List[int]: def atenbincount(self: List[int], weights: Optional[List[int]] = None, minlength: int = 0) -> List[int]:
return [hacky_get_unknown_dimension_size()] return [hacky_get_unknown_dimension_size()]
def atenlinalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
if dim is None:
dim = list(range(len(self)))
return upstream_shape_helpers.mean_dim(self, dim, keepdim, dtype)
# ============================================================================== # ==============================================================================
# Shape library generator main(). # Shape library generator main().
# ============================================================================== # ==============================================================================

View File

@ -365,6 +365,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
# Misc tensor ops. # Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")

View File

@ -383,3 +383,111 @@ class ReduceMaxUnsignedIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule()) @register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule())
def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils): def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (3, 4, 5))) module.forward(torch.randint(100, (3, 4, 5)))
# ==============================================================================
class ReduceL1NormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=0, ord=1)
@register_test_case(module_factory=lambda: ReduceL1NormModule())
def ReduceL1NormModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceL1NormWithDTypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=0, ord=1, dtype=torch.float64)
@register_test_case(module_factory=lambda: ReduceL1NormWithDTypeModule())
def ReduceL1NormWithDTypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float32))
# ==============================================================================
class ReduceL2NormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=0)
@register_test_case(module_factory=lambda: ReduceL2NormModule())
def ReduceL2NormModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceLN3NormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=0, ord=-3)
@register_test_case(module_factory=lambda: ReduceLN3NormModule())
def ReduceLN3NormModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceL3NormAllDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=None, ord=3)
@register_test_case(module_factory=lambda: ReduceL3NormAllDimsModule())
def ReduceL3NormAllDimsModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceL3NormKeepDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, keepdim=True, ord=3)
@register_test_case(module_factory=lambda: ReduceL3NormKeepDimModule())
def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))