mirror of https://github.com/llvm/torch-mlir
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch dialect, including the necessary shape function. It also extends the conversion of reduction operators to support lowering of AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end tests to validate the lowering. There exist several opportunities to make this lowering optimal and robust. For instance, in its current form, the translation does not support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise each element to the power 1.0. Similarly, L2 norms could benefit from strength reduction. Since the canonicalization pass is not able to apply these optimizations, we should consider applying them during the linalg lowering itself.pull/869/head
parent
3fb54cba4c
commit
f18b2be911
|
@ -3625,6 +3625,33 @@ def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$ord,
|
||||
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||
Torch_BoolType:$keepdim,
|
||||
AnyTorchOptionalIntType:$dtype
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenLinalgVectorNormOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||
}
|
||||
void AtenLinalgVectorNormOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
|
@ -67,7 +68,8 @@ public:
|
|||
|
||||
bool keepDim = false;
|
||||
if (!matchPattern(maxDimOp.keepdim(), m_TorchConstantBool(&keepDim)))
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
maxDimOp, "aten.max_dim requires boolean value for keepdim");
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(maxDimOp.dim(), m_TorchConstantInt(&dim)))
|
||||
|
@ -173,9 +175,8 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
|
||||
Operation *op,
|
||||
Type elementType) {
|
||||
static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
||||
Operation *op, Type elementType) {
|
||||
if (isa<AtenSumOp, AtenSumDimIntListOp>(op))
|
||||
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
|
||||
|
@ -195,14 +196,18 @@ static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
|
|||
elementType.getIntOrFloatBitWidth())));
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgNeutralElementForReduceOp");
|
||||
if (isa<AtenLinalgVectorNormOp>(op))
|
||||
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
|
||||
op->emitError("unimplemented lowering in createInitElementForReduceOp");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static Value createLinalgPayloadCalculationForReduceOp(
|
||||
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
|
||||
ArrayRef<Value> operands, Type resultElementType) {
|
||||
static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
||||
ValueRange payloadArgs,
|
||||
Operation *op,
|
||||
ArrayRef<Value> operands,
|
||||
Type resultElementType) {
|
||||
if (isa<AtenSumOp, AtenSumDimIntListOp>(op)) {
|
||||
Value self =
|
||||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||
|
@ -228,14 +233,180 @@ static Value createLinalgPayloadCalculationForReduceOp(
|
|||
if (intType.isSigned())
|
||||
return b.create<arith::MaxSIOp>(loc, self, result);
|
||||
}
|
||||
} else if (isa<AtenLinalgVectorNormOp>(op)) {
|
||||
// This creates payload for only the first of the two linalg.generic ops.
|
||||
// TODO: Short-circuit operations if `ord` is zero or one.
|
||||
Value elem = payloadArgs[0];
|
||||
Value result = payloadArgs[1];
|
||||
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
|
||||
auto abs = b.create<math::AbsOp>(loc, self);
|
||||
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
|
||||
Value ord = convertScalarToDtype(b, loc, adaptor.ord(), resultElementType);
|
||||
auto pow = b.create<math::PowFOp>(loc, abs, ord);
|
||||
return b.create<arith::AddFOp>(loc, pow, result);
|
||||
}
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgPayloadCalculationForReduceOp");
|
||||
op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ConvertReductionOp : public ConversionPattern {
|
||||
private:
|
||||
/// Given a reduction operation that has the `keepdim` attribute and the
|
||||
/// (optional) `dim` attribute, return the source tensor operand and the
|
||||
/// literal values of the attributes or failure otherwise.
|
||||
template <typename T>
|
||||
FailureOr<torch_to_linalg::ReductionOpInfo>
|
||||
computeReductionOpInfoForDimVariantOp(
|
||||
T op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
||||
typename T::Adaptor adaptor(operands);
|
||||
opInfo.tensorOperand = adaptor.self();
|
||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
||||
|
||||
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&opInfo.keepDim)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"`keepdim` must be a constant bool");
|
||||
|
||||
SmallVector<int64_t> dimList;
|
||||
if (matchPattern(op.dim(), m_TorchConstantIntList(dimList))) {
|
||||
// Fix negative dimensions, if any, before adding to the list.
|
||||
for (int64_t dim : dimList) {
|
||||
dim = toPositiveDim(dim, inputType.getRank());
|
||||
// Drop invalid dimensions
|
||||
if (isValidDim(dim, inputType.getRank()))
|
||||
opInfo.dimSet.insert(dim);
|
||||
}
|
||||
} else if (op.dim().getType().template isa<Torch::NoneType>()) {
|
||||
// If no dimensions were specified, reduce along all dimensions
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||
opInfo.dimSet.insert(i);
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "`dim` argument must be a constant int list or None");
|
||||
}
|
||||
|
||||
return opInfo;
|
||||
}
|
||||
|
||||
/// Given a reduction operation, return the source tensor operand and the
|
||||
/// literal values of the `keepdim` and `dim` attributes, if any, or failure
|
||||
/// otherwise.
|
||||
FailureOr<torch_to_linalg::ReductionOpInfo>
|
||||
computeReductionOpInfo(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
||||
|
||||
if (isa<AtenMaxOp, AtenSumOp>(op)) {
|
||||
opInfo.tensorOperand = operands[0];
|
||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
||||
|
||||
// `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
|
||||
// input tensor.
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||
opInfo.dimSet.insert(i);
|
||||
|
||||
return opInfo;
|
||||
}
|
||||
|
||||
if (auto sumOp = dyn_cast<AtenSumDimIntListOp>(op))
|
||||
return computeReductionOpInfoForDimVariantOp(sumOp, operands, rewriter);
|
||||
|
||||
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op))
|
||||
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);
|
||||
|
||||
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
|
||||
}
|
||||
|
||||
/// Generate a linalg.generic operation for pointwise exponentiation of each
|
||||
/// element.
|
||||
Value createElementwiseExp(Location loc, Type elemType, Value exponent,
|
||||
Value inputTensor,
|
||||
const torch_to_linalg::ReductionOpInfo &opInfo,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
bool err = false;
|
||||
auto powBodyBuilder = [&](OpBuilder &builder, Location loc,
|
||||
ValueRange payloadArgs) {
|
||||
Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], elemType);
|
||||
auto result = builder.create<math::PowFOp>(loc, elem, exponent);
|
||||
if (result)
|
||||
builder.create<linalg::YieldOp>(loc, Value{result});
|
||||
err = !result;
|
||||
};
|
||||
|
||||
Value powOp = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
rewriter, loc, {inputTensor}, elemType, powBodyBuilder);
|
||||
return err ? Value{} : powOp;
|
||||
}
|
||||
|
||||
FailureOr<Value> createSecondReductionForVectorNormOp(
|
||||
Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp,
|
||||
Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Cast `ord` to float so that we can readily pass it math.powf.
|
||||
Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);
|
||||
|
||||
// TODO: Add support for ord = {0, +inf, -inf}.
|
||||
auto epsilon = 1e-5;
|
||||
auto ordLiteral = 0.0;
|
||||
if (matchPattern(ordValue, m_TorchConstantFloat(&ordLiteral)) &&
|
||||
fabs(ordLiteral) < epsilon)
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented: L0 norm");
|
||||
|
||||
if (std::isinf(ordLiteral))
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented: ord = +/- inf");
|
||||
|
||||
// Raise each summed value to the inverse of the order of the norm.
|
||||
Attribute oneAttr = rewriter.getFloatAttr(elemType, 1.0);
|
||||
auto oneValue = rewriter.create<arith::ConstantOp>(loc, oneAttr);
|
||||
auto inverseOrdValue =
|
||||
rewriter.create<arith::DivFOp>(loc, oneValue, ordValue);
|
||||
|
||||
// Use the results of the first reduction operation from above to generate
|
||||
// a second reduction operation.
|
||||
Value reduceOp = createElementwiseExp(loc, elemType, inverseOrdValue,
|
||||
firstReduction, opInfo, rewriter);
|
||||
if (!reduceOp)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to create linalg.generic operation for element-wise "
|
||||
"exponentiation");
|
||||
|
||||
return reduceOp;
|
||||
}
|
||||
|
||||
/// Generate a linalg.generic operation for a reduction.
|
||||
Value createReductionOp(Location loc, Type elemType, Operation *op,
|
||||
ArrayRef<Value> operands,
|
||||
const torch_to_linalg::ReductionOpInfo &opInfo,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
bool err = false;
|
||||
auto reductionBodyBuilder = [&](OpBuilder &builder, Location loc,
|
||||
ValueRange payloadArgs) {
|
||||
Value result = createLinalgPayloadForReduceOp(builder, loc, payloadArgs,
|
||||
op, operands, elemType);
|
||||
if (result)
|
||||
builder.create<linalg::YieldOp>(loc, result);
|
||||
err = !result;
|
||||
};
|
||||
|
||||
Value initElem = createInitElementForReduceOp(rewriter, loc, op, elemType);
|
||||
Value reduceOp = torch_to_linalg::createReductionLinalgGeneric(
|
||||
rewriter, loc, opInfo, initElem, reductionBodyBuilder);
|
||||
return err ? Value{} : reduceOp;
|
||||
}
|
||||
|
||||
/// Depending on the operation, check validity of the result's element type.
|
||||
LogicalResult
|
||||
validateReductionElementType(Operation *op, Type elemType,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (isa<AtenLinalgVectorNormOp>(op) && !elemType.isa<mlir::FloatType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only float types are valid for vector norm ops");
|
||||
// No checks for all other reduction operations
|
||||
return success();
|
||||
}
|
||||
|
||||
public:
|
||||
ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
|
@ -244,67 +415,42 @@ public:
|
|||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid operand or result types to use with linalg on tensors");
|
||||
|
||||
// Every reduce operation must set a value for the `dimSet`,
|
||||
// `tensorOperand`, and `keepDim` in accordance with their specification.
|
||||
DenseSet<int64_t> dimSet;
|
||||
Value tensorOperand;
|
||||
bool keepDim = false;
|
||||
if (isa<AtenSumOp>(op) || isa<AtenMaxOp>(op)) {
|
||||
tensorOperand = operands[0];
|
||||
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
|
||||
|
||||
// `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
|
||||
// input tensor.
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||
dimSet.insert(i);
|
||||
} else if (auto sumDimIntListOp = dyn_cast<AtenSumDimIntListOp>(op)) {
|
||||
tensorOperand = operands[0];
|
||||
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
|
||||
|
||||
if (!matchPattern(sumDimIntListOp.keepdim(),
|
||||
m_TorchConstantBool(&keepDim)))
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> dimList;
|
||||
if (!matchPattern(sumDimIntListOp.dim(), m_TorchConstantIntList(dimList)))
|
||||
return failure();
|
||||
for (auto dim : dimList) {
|
||||
// Torch allows for negative values in dimSet to go in reverse
|
||||
// order in the dimensions of the input tensor.
|
||||
dim = dim >= 0 ? dim : dim + inputType.getRank();
|
||||
// Drop invalid dimensions
|
||||
if (dim < inputType.getRank())
|
||||
dimSet.insert(dim);
|
||||
}
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
|
||||
}
|
||||
FailureOr<torch_to_linalg::ReductionOpInfo> opInfo =
|
||||
computeReductionOpInfo(op, operands, rewriter);
|
||||
if (failed(opInfo))
|
||||
return opInfo;
|
||||
|
||||
Location loc = op->getLoc();
|
||||
auto resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
Value initElem = createLinalgNeutralElementForReduceOp(
|
||||
rewriter, loc, op, resultType.getElementType());
|
||||
Type elemType = resultType.getElementType();
|
||||
LogicalResult elemTypeCheck =
|
||||
validateReductionElementType(op, elemType, rewriter);
|
||||
if (failed(elemTypeCheck))
|
||||
return elemTypeCheck;
|
||||
|
||||
bool hadErrorCreatingPayload = false;
|
||||
Value generic = torch_to_linalg::createReductionLinalgGeneric(
|
||||
rewriter, loc, tensorOperand, dimSet, keepDim, initElem,
|
||||
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
||||
Value result = createLinalgPayloadCalculationForReduceOp(
|
||||
b, loc, payloadArgs, op, operands, resultType.getElementType());
|
||||
if (!result) {
|
||||
hadErrorCreatingPayload = true;
|
||||
return;
|
||||
Value reduceOp =
|
||||
createReductionOp(loc, elemType, op, operands, *opInfo, rewriter);
|
||||
if (!reduceOp)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to create linalg.generic operation for reduction");
|
||||
|
||||
// If this is aten.linalg_vector_norm op, then we need to generate another
|
||||
// linalg.generic op that references the first linalg.generic op.
|
||||
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
|
||||
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
|
||||
FailureOr<Value> secondReduceOp = createSecondReductionForVectorNormOp(
|
||||
loc, elemType, normOp, adaptor.ord(), reduceOp, *opInfo, rewriter);
|
||||
if (failed(secondReduceOp))
|
||||
return secondReduceOp;
|
||||
reduceOp = *secondReduceOp;
|
||||
}
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
});
|
||||
|
||||
if (hadErrorCreatingPayload)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, reduceOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -319,5 +465,6 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
|
|||
target.addIllegalOp<AtenSumOp>();
|
||||
target.addIllegalOp<AtenSumDimIntListOp>();
|
||||
target.addIllegalOp<AtenMaxOp>();
|
||||
target.addIllegalOp<AtenLinalgVectorNormOp>();
|
||||
patterns.add<ConvertReductionOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -35,113 +35,6 @@ template <typename elementType> static bool hasElementType(Value tensor) {
|
|||
return tensorElementType.isa<elementType>();
|
||||
}
|
||||
|
||||
static Value createElementwiseLinalgGeneric(
|
||||
OpBuilder &b, Location loc, ValueRange tensorOperands,
|
||||
Type resultElementType,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
||||
// The overall error handling strategy here is best viewed by thinking about
|
||||
// what happens for a single result dimension. This loop not structured that
|
||||
// way because it is hard to create the affine maps for each operand unless
|
||||
// we structure the loop to iterate over tensor operands as the outer loop
|
||||
// instead of inner loop. This pseudocode gives better intuition:
|
||||
// ```
|
||||
// for each result dimension:
|
||||
// for each tensor operand:
|
||||
// if it doesn't even have high enough rank relative to the result:
|
||||
// continue
|
||||
// if it is a static size-1 along this result dimension:
|
||||
// continue
|
||||
// if this is the first tensor operand that didn't continue above:
|
||||
// take its dimension size as the size of the non-broadcasted
|
||||
// traversal along this dimension (this may include a dynamic size-1,
|
||||
// **non-broadcasted** traversal!)
|
||||
// emit error check "if the size does not match the non-broadcasted
|
||||
// traversal size along this dimension, error"
|
||||
// ```
|
||||
SmallVector<int64_t> operandRanks;
|
||||
operandRanks.resize(tensorOperands.size());
|
||||
llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
|
||||
return tensor.getType().dyn_cast<RankedTensorType>().getRank();
|
||||
});
|
||||
|
||||
auto resultRankIt =
|
||||
std::max_element(operandRanks.begin(), operandRanks.end());
|
||||
assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
|
||||
int64_t resultRank = *resultRankIt;
|
||||
|
||||
// Initialize the resultShape to all 1's, as a fallback in case
|
||||
// all sizes along that result dimension are statically 1.
|
||||
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
|
||||
SmallVector<Value> resultShape(resultRank, c1);
|
||||
SmallVector<AffineMap> indexingMaps;
|
||||
for (Value tensorOperand : tensorOperands) {
|
||||
SmallVector<AffineExpr> exprs;
|
||||
auto type = tensorOperand.getType().cast<RankedTensorType>();
|
||||
for (auto size : llvm::enumerate(type.getShape())) {
|
||||
// If the size is statically known to be 1, we don't want any
|
||||
// error guards to be spuriously emitted, since we are specifically
|
||||
// allowing size-1 broadcasts in this case, as they correspond to a
|
||||
// constant-0 indexing map.
|
||||
if (size.value() == 1) {
|
||||
exprs.push_back(b.getAffineConstantExpr(0));
|
||||
continue;
|
||||
}
|
||||
|
||||
// The rank of this operand might be smaller than the overall rank of
|
||||
// the broadcast. Add an offset to correlate it to the correct
|
||||
// dimension of the result.
|
||||
auto resultDim = size.index() + (resultRank - type.getRank());
|
||||
|
||||
// The generated linalg op will now be iterating along the full size
|
||||
// of this dimension. Record that fact.
|
||||
exprs.push_back(b.getAffineDimExpr(resultDim));
|
||||
|
||||
// Now, we need to ensure that such iteration is not going to trigger
|
||||
// undefined behavior, by doing appropriate checks against the current
|
||||
// dimension size.
|
||||
auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());
|
||||
|
||||
// If the result size of this dimension has so far only hit the
|
||||
// statically-known-to-be-1 case above (i.e., we have not yet assigned a
|
||||
// new Value to `resultShape[resultDim]`), then we have no other dynamic
|
||||
// values to check against, and merely need to record the current
|
||||
// dimension size.
|
||||
if (resultShape[resultDim] == c1) {
|
||||
resultShape[resultDim] = currentDimSize;
|
||||
continue;
|
||||
}
|
||||
|
||||
// We prohibit the size-1 dynamic broadcasting scenario, so just check
|
||||
// for exact equality with the running result size.
|
||||
// This is the check which protects against the undefined behavior of
|
||||
// the generated linalg op in the case of iterating two operands with
|
||||
// dimensions sizes that are expected to match.
|
||||
auto equalToRunning =
|
||||
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
|
||||
resultShape[resultDim], currentDimSize);
|
||||
b.create<cf::AssertOp>(loc, equalToRunning,
|
||||
"mismatched size for broadcast");
|
||||
}
|
||||
indexingMaps.push_back(AffineMap::get(
|
||||
/*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
|
||||
}
|
||||
|
||||
SmallVector<StringRef> iteratorTypes(resultRank,
|
||||
getParallelIteratorTypeName());
|
||||
// Add the indexing map for the outs init tensor.
|
||||
indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));
|
||||
|
||||
Value initTensor = b.create<linalg::InitTensorOp>(
|
||||
loc, getAsOpFoldResult(resultShape), resultElementType);
|
||||
return b
|
||||
.create<linalg::GenericOp>(loc,
|
||||
/*resultTensorTypes=*/initTensor.getType(),
|
||||
/*inputs=*/tensorOperands,
|
||||
/*outputs=*/initTensor, indexingMaps,
|
||||
iteratorTypes, bodyBuild)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
|
||||
arith::CmpIPredicate ispred>
|
||||
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
|
||||
|
@ -964,7 +857,7 @@ public:
|
|||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
bool hadErrorCreatingPayload = false;
|
||||
Value generic = createElementwiseLinalgGeneric(
|
||||
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
rewriter, loc, tensorOperands, resultType.getElementType(),
|
||||
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
||||
Value result = createLinalgPayloadCalculationForElementwiseOp(
|
||||
|
@ -1031,7 +924,7 @@ public:
|
|||
Value zeroVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(elementType));
|
||||
|
||||
Value finalRes = createElementwiseLinalgGeneric(
|
||||
Value finalRes = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
rewriter, loc, {target}, elementType,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value targetVal = args[0];
|
||||
|
@ -1068,8 +961,9 @@ public:
|
|||
/*inclusive=*/false);
|
||||
DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());
|
||||
|
||||
auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet};
|
||||
finalRes = torch_to_linalg::createReductionLinalgGeneric(
|
||||
rewriter, loc, finalRes, dimSet, /*keepDim=*/false,
|
||||
rewriter, loc, opInfo,
|
||||
/*initElem=*/zeroVal,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value newVal = args[0];
|
||||
|
|
|
@ -98,23 +98,22 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
|
|||
}
|
||||
|
||||
Value torch_to_linalg::createReductionLinalgGeneric(
|
||||
OpBuilder &b, Location loc, Value tensorOperand,
|
||||
const DenseSet<int64_t> &dimSet, bool keepDim, Value initElem,
|
||||
OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
||||
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
|
||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
||||
|
||||
// Get the result shape by obtaining the size of each
|
||||
// dimension in the input tensor that is not getting reduced.
|
||||
// If `keepDim` is true, the rank of the output tensor
|
||||
// If `opInfo.keepDim` is true, the rank of the output tensor
|
||||
// is kept the same as the rank of the input tensor, and the
|
||||
// reduced dimensions are set to have size 1.
|
||||
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
|
||||
SmallVector<Value> resultShape;
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++) {
|
||||
auto currentDimSize = b.create<tensor::DimOp>(loc, tensorOperand, i);
|
||||
if (!dimSet.contains(i))
|
||||
auto currentDimSize = b.create<tensor::DimOp>(loc, opInfo.tensorOperand, i);
|
||||
if (!opInfo.dimSet.contains(i))
|
||||
resultShape.push_back(currentDimSize);
|
||||
else if (keepDim)
|
||||
else if (opInfo.keepDim)
|
||||
resultShape.push_back(c1);
|
||||
}
|
||||
|
||||
|
@ -127,11 +126,11 @@ Value torch_to_linalg::createReductionLinalgGeneric(
|
|||
for (auto size : llvm::enumerate(inputType.getShape())) {
|
||||
exprs.push_back(b.getAffineDimExpr(size.index()));
|
||||
|
||||
if (dimSet.contains(size.index())) {
|
||||
if (opInfo.dimSet.contains(size.index())) {
|
||||
iteratorTypes.push_back(getReductionIteratorTypeName());
|
||||
// If `keepDim`, create affine map to the first element
|
||||
// If `opInfo.keepDim`, create affine map to the first element
|
||||
// in the current dimension.
|
||||
if (keepDim)
|
||||
if (opInfo.keepDim)
|
||||
resultExprs.push_back(b.getAffineConstantExpr(0));
|
||||
} else {
|
||||
iteratorTypes.push_back(getParallelIteratorTypeName());
|
||||
|
@ -146,7 +145,114 @@ Value torch_to_linalg::createReductionLinalgGeneric(
|
|||
return b
|
||||
.create<linalg::GenericOp>(
|
||||
loc, /*resultTensorTypes=*/accumulator.getType(),
|
||||
/*inputs=*/tensorOperand,
|
||||
/*inputs=*/opInfo.tensorOperand,
|
||||
/*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
Value torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
OpBuilder &b, Location loc, ValueRange tensorOperands,
|
||||
Type resultElementType,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
||||
// The overall error handling strategy here is best viewed by thinking about
|
||||
// what happens for a single result dimension. This loop not structured that
|
||||
// way because it is hard to create the affine maps for each operand unless
|
||||
// we structure the loop to iterate over tensor operands as the outer loop
|
||||
// instead of inner loop. This pseudocode gives better intuition:
|
||||
// ```
|
||||
// for each result dimension:
|
||||
// for each tensor operand:
|
||||
// if it doesn't even have high enough rank relative to the result:
|
||||
// continue
|
||||
// if it is a static size-1 along this result dimension:
|
||||
// continue
|
||||
// if this is the first tensor operand that didn't continue above:
|
||||
// take its dimension size as the size of the non-broadcasted
|
||||
// traversal along this dimension (this may include a dynamic size-1,
|
||||
// **non-broadcasted** traversal!)
|
||||
// emit error check "if the size does not match the non-broadcasted
|
||||
// traversal size along this dimension, error"
|
||||
// ```
|
||||
SmallVector<int64_t> operandRanks;
|
||||
operandRanks.resize(tensorOperands.size());
|
||||
llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
|
||||
return tensor.getType().dyn_cast<RankedTensorType>().getRank();
|
||||
});
|
||||
|
||||
auto resultRankIt =
|
||||
std::max_element(operandRanks.begin(), operandRanks.end());
|
||||
assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
|
||||
int64_t resultRank = *resultRankIt;
|
||||
|
||||
// Initialize the resultShape to all 1's, as a fallback in case
|
||||
// all sizes along that result dimension are statically 1.
|
||||
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
|
||||
SmallVector<Value> resultShape(resultRank, c1);
|
||||
SmallVector<AffineMap> indexingMaps;
|
||||
for (Value tensorOperand : tensorOperands) {
|
||||
SmallVector<AffineExpr> exprs;
|
||||
auto type = tensorOperand.getType().cast<RankedTensorType>();
|
||||
for (auto size : llvm::enumerate(type.getShape())) {
|
||||
// If the size is statically known to be 1, we don't want any
|
||||
// error guards to be spuriously emitted, since we are specifically
|
||||
// allowing size-1 broadcasts in this case, as they correspond to a
|
||||
// constant-0 indexing map.
|
||||
if (size.value() == 1) {
|
||||
exprs.push_back(b.getAffineConstantExpr(0));
|
||||
continue;
|
||||
}
|
||||
|
||||
// The rank of this operand might be smaller than the overall rank of
|
||||
// the broadcast. Add an offset to correlate it to the correct
|
||||
// dimension of the result.
|
||||
auto resultDim = size.index() + (resultRank - type.getRank());
|
||||
|
||||
// The generated linalg op will now be iterating along the full size
|
||||
// of this dimension. Record that fact.
|
||||
exprs.push_back(b.getAffineDimExpr(resultDim));
|
||||
|
||||
// Now, we need to ensure that such iteration is not going to trigger
|
||||
// undefined behavior, by doing appropriate checks against the current
|
||||
// dimension size.
|
||||
auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());
|
||||
|
||||
// If the result size of this dimension has so far only hit the
|
||||
// statically-known-to-be-1 case above (i.e., we have not yet assigned a
|
||||
// new Value to `resultShape[resultDim]`), then we have no other dynamic
|
||||
// values to check against, and merely need to record the current
|
||||
// dimension size.
|
||||
if (resultShape[resultDim] == c1) {
|
||||
resultShape[resultDim] = currentDimSize;
|
||||
continue;
|
||||
}
|
||||
|
||||
// We prohibit the size-1 dynamic broadcasting scenario, so just check
|
||||
// for exact equality with the running result size.
|
||||
// This is the check which protects against the undefined behavior of
|
||||
// the generated linalg op in the case of iterating two operands with
|
||||
// dimensions sizes that are expected to match.
|
||||
auto equalToRunning =
|
||||
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
|
||||
resultShape[resultDim], currentDimSize);
|
||||
b.create<cf::AssertOp>(loc, equalToRunning,
|
||||
"mismatched size for broadcast");
|
||||
}
|
||||
indexingMaps.push_back(AffineMap::get(
|
||||
/*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
|
||||
}
|
||||
|
||||
SmallVector<StringRef> iteratorTypes(resultRank,
|
||||
getParallelIteratorTypeName());
|
||||
// Add the indexing map for the outs init tensor.
|
||||
indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));
|
||||
|
||||
Value initTensor = b.create<linalg::InitTensorOp>(
|
||||
loc, getAsOpFoldResult(resultShape), resultElementType);
|
||||
return b
|
||||
.create<linalg::GenericOp>(loc,
|
||||
/*resultTensorTypes=*/initTensor.getType(),
|
||||
/*inputs=*/tensorOperands,
|
||||
/*outputs=*/initTensor, indexingMaps,
|
||||
iteratorTypes, bodyBuild)
|
||||
.getResult(0);
|
||||
}
|
||||
|
|
|
@ -13,6 +13,12 @@ namespace mlir {
|
|||
namespace torch {
|
||||
namespace torch_to_linalg {
|
||||
|
||||
struct ReductionOpInfo {
|
||||
bool keepDim;
|
||||
Value tensorOperand;
|
||||
DenseSet<int64_t> dimSet;
|
||||
};
|
||||
|
||||
// Helper function to get the padding tensor given the padding int values.
|
||||
Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
||||
SmallVectorImpl<int64_t> &lowPaddingInts,
|
||||
|
@ -33,16 +39,21 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
|
|||
Value kernelSizeInt, Value strideInt,
|
||||
bool ceilMode = false);
|
||||
|
||||
// Create a reduction of `tensorOperand`, reducing along the dimensions
|
||||
// in `dimSet`. If `keepDim` is true, the output tensor is the same
|
||||
// rank as the `tensorOperand` and reduced dimensions are set to size 1.
|
||||
// `initElem` is the element used to initialize the output tensor
|
||||
// where the reduction will be stored.
|
||||
// Create a reduction of `opInfo.tensorOperand`, reducing along the dimensions
|
||||
// in `opInfo.dimSet`. If `opInfo.keepDim` is true, the output tensor is the
|
||||
// same rank as the `opInfo.tensorOperand` and reduced dimensions are set to
|
||||
// size 1. `initElem` is the element used to initialize the output tensor where
|
||||
// the reduction will be stored.
|
||||
Value createReductionLinalgGeneric(
|
||||
OpBuilder &b, Location loc, Value tensorOperand,
|
||||
const DenseSet<int64_t> &dimSet, bool keepDim, Value initElem,
|
||||
OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
|
||||
|
||||
// Create a pointwise operation that uses values in `tensorOperands`, such that
|
||||
// the element type of the resulting tensor is `resultElementType`.
|
||||
Value createElementwiseLinalgGeneric(
|
||||
OpBuilder &b, Location loc, ValueRange tensorOperands,
|
||||
Type resultElementType,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
|
||||
} // namespace torch_to_linalg
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -988,6 +988,14 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
if (auto scalarImplicit = dyn_cast<AtenScalarImplicitOp>(op))
|
||||
return visitAtenScalarImplicitOp(scalarImplicit, operands);
|
||||
|
||||
if (auto vectorNorm = dyn_cast<AtenLinalgVectorNormOp>(op)) {
|
||||
Type defaultDtype = operands[0]->getValue().dtype;
|
||||
Type dtype = getDtypeOrDefault(vectorNorm.getContext(), vectorNorm.dtype(),
|
||||
defaultDtype);
|
||||
return visitReductionAlongDimIntListOp(
|
||||
vectorNorm, vectorNorm.dim(), vectorNorm.keepdim(), dtype, operands);
|
||||
}
|
||||
|
||||
// Otherwise, this is an unknown operation. Just mark all results as
|
||||
// having reached a pessimistic fixpoint.
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
|
|
|
@ -3078,6 +3078,29 @@ module {
|
|||
%none = torch.constant.none
|
||||
return %none : !torch.none
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.linalg_vector_norm"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
|
||||
%none = torch.constant.none
|
||||
%true = torch.constant.bool true
|
||||
%0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.list<int>) {
|
||||
%2 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
%3 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
|
||||
%4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
torch.prim.If.yield %4 : !torch.list<int>
|
||||
} else {
|
||||
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.Loop %2, %true, init() {
|
||||
^bb0(%arg5: !torch.int):
|
||||
%6 = torch.aten.append.t %3, %arg5 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
%4 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
|
||||
%5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.mean_dim(%arg0, %3, %arg3, %4) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
torch.prim.If.yield %5 : !torch.list<int>
|
||||
}
|
||||
return %1 : !torch.list<int>
|
||||
}
|
||||
}
|
||||
)mlir");
|
||||
#pragma clang diagnostic pop
|
||||
|
|
|
@ -946,6 +946,11 @@ def hacky_get_unknown_dimension_size():
|
|||
def aten〇bincount(self: List[int], weights: Optional[List[int]] = None, minlength: int = 0) -> List[int]:
|
||||
return [hacky_get_unknown_dimension_size()]
|
||||
|
||||
def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
if dim is None:
|
||||
dim = list(range(len(self)))
|
||||
return upstream_shape_helpers.mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
# ==============================================================================
|
||||
# Shape library generator main().
|
||||
# ==============================================================================
|
||||
|
|
|
@ -365,6 +365,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
||||
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
||||
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
||||
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
|
||||
|
||||
# Misc tensor ops.
|
||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||
|
|
|
@ -383,3 +383,111 @@ class ReduceMaxUnsignedIntModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule())
|
||||
def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 4, 5)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceL1NormModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=0, ord=1)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL1NormModule())
|
||||
def ReduceL1NormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceL1NormWithDTypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=0, ord=1, dtype=torch.float64)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL1NormWithDTypeModule())
|
||||
def ReduceL1NormWithDTypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceL2NormModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=0)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL2NormModule())
|
||||
def ReduceL2NormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceLN3NormModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=0, ord=-3)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceLN3NormModule())
|
||||
def ReduceLN3NormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceL3NormAllDimsModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=None, ord=3)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL3NormAllDimsModule())
|
||||
def ReduceL3NormAllDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceL3NormKeepDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, keepdim=True, ord=3)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL3NormKeepDimModule())
|
||||
def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
|
Loading…
Reference in New Issue