LLVM bump

Major changes: OpTrait changed to Trait, SelectOp moved to the arith dialect, and AssertOp moved to the cf dialect.
pull/608/head
Nirvedh 2022-02-12 18:47:12 +00:00 committed by nirvedhmeshram
parent 442ff4605c
commit f8cb32faf0
19 changed files with 180 additions and 98 deletions
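
The mechanical impact on downstream IR is easiest to see side by side. A minimal sketch with illustrative SSA names (not taken from this patch); in ODS, the parallel change is spelling `list<OpTrait>` as `list<Trait>`:

// Before this bump, these ops were spelled in the standard dialect:
//   %0 = select %pred, %a, %b : i64
//   assert %cond, "expected condition to hold"
//   cond_br %flag, ^bb1, ^bb2
// After this bump they live in arith and cf:
%0 = arith.select %pred, %a, %b : i64
cf.assert %cond, "expected condition to hold"
cf.cond_br %flag, ^bb1, ^bb2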

View File

@@ -20,11 +20,11 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
 // Base class.
 //===----------------------------------------------------------------------===//
 
-class TMTensor_PureOp<string mnemonic, list<OpTrait> traits = []> :
+class TMTensor_PureOp<string mnemonic, list<Trait> traits = []> :
     Op<TMTensor_Dialect, mnemonic, traits> {
 }
 
-class TMTensor_Op<string mnemonic, list<OpTrait> traits = []> :
+class TMTensor_Op<string mnemonic, list<Trait> traits = []> :
     TMTensor_PureOp<mnemonic, !listconcat(traits,
         [AttrSizedOperandSegments,
          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

View File

@@ -15,7 +15,6 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Attributes.h"

@@ -1 +1 @@
-Subproject commit 84fe34a0b7fdd7bbf179981d1583693d5d5ec68b
+Subproject commit bfc6fbfb65f6d490e6a1ba4eb6c734d2b4494dd1

View File

@@ -38,7 +38,7 @@ def Torch_Dialect : Dialect {
   let hasConstantMaterializer = 1;
 }
 
-class TorchOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
+class TorchOpTrait<string name> : NativeOpTrait<""> {
   let trait = name;
   let cppNamespace = "::mlir::torch::Torch::OpTrait";
 }

View File

@@ -18,7 +18,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
-class Torch_Op<string mnemonic, list<OpTrait> traits = []>
+class Torch_Op<string mnemonic, list<Trait> traits = []>
     : Op<Torch_Dialect, mnemonic, traits> {
 }
@@ -503,8 +503,8 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
   let arguments = (ins Torch_BoolType:$condition);
   let results = (outs Variadic<AnyTorchType>:$results);
   let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion);
-  let printer = [{ return ::print(p, *this); }];
-  let parser = [{ return ::parsePrimIfOp(parser, result); }];
+  // Indicate that the operation has a custom parser and printer method.
+  let hasCustomAssemblyFormat = 1;
   let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
   let hasCanonicalizer = 1;
 }
@@ -594,8 +594,8 @@ def Torch_ConstantIntOp : Torch_Op<"constant.int",
   let results = (outs
     Torch_IntType:$result
   );
-  let printer = [{ return ::print(p, *this); }];
-  let parser = [{ return ::parseConstantIntOp(parser, result); }];
+  // Indicate that the operation has a custom parser and printer method.
+  let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
 }

View File

@@ -30,7 +30,7 @@ class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type
   }];
   let parser = [{
-    if (parser.parseLess())
+    if ($_parser.parseLess())
       return Type();
     Type containedType;
     if ($_parser.parseType(containedType))
@@ -39,7 +39,6 @@ class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type
       return Type();
     return get($_ctxt, containedType);
   }];
-
   let builders = [
     TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
       return Base::get(containedType.getContext(), containedType);
@@ -61,9 +60,9 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
   }];
   let parser = [{
-    if (parser.parseLess())
+    if ($_parser.parseLess())
       return Type();
     std::string className;
     if ($_parser.parseOptionalString(&className))
       return Type();
     if ($_parser.parseGreater())
@@ -349,7 +348,7 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
   }];
   let parser = [{
-    if (parser.parseLess())
+    if ($_parser.parseLess())
       return Type();
     Type keyType;
     if ($_parser.parseType(keyType))
@@ -363,7 +362,6 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
       return Type();
     return get($_ctxt, keyType, valueType);
   }];
-
   let builders = [
     TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
                                        "::mlir::Type":$valueType), [{

View File

@@ -17,7 +17,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
 include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td"
 
-class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>
+class TorchConversion_Op<string mnemonic, list<Trait> traits = []>
     : Op<TorchConversion_Dialect, mnemonic, traits> {
 }

View File

@@ -25,6 +25,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
 std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
+std::unique_ptr<OperationPass<FuncOp>> createMungeMemrefCopyPass();
+
 } // namespace RefBackend
 } // namespace torch
 } // namespace mlir

View File

@@ -29,4 +29,10 @@ def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "FuncOp"> {
   let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
 }
 
+def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "FuncOp"> {
+  let summary = "Munge memref.copy to linalg.copy";
+  let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
+  let dependentDialects = ["memref::MemRefDialect"];
+}
+
 #endif // TORCHMLIR_REFBACKEND_PASSES

View File

@@ -11,8 +11,10 @@
 
 #include "../PassDetail.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Traits.h"
@@ -78,7 +80,8 @@ static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
       b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
   Value predDimGEZero =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
-  Value dimInt = b.create<SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
+  Value dimInt =
+      b.create<arith::SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
   return dimInt;
 }
@@ -91,12 +94,12 @@ static void assertIsValidDim(OpBuilder &b, Location loc, Value dim,
       b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
   Value predGEZero =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
-  b.create<AssertOp>(loc, predGEZero,
-                     b.getStringAttr("dim must be greater or equal to zero"));
+  b.create<cf::AssertOp>(
+      loc, predGEZero, b.getStringAttr("dim must be greater or equal to zero"));
   Value predLTInputRank =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank);
-  b.create<AssertOp>(loc, predLTInputRank,
-                     b.getStringAttr("dim must be smaller than inputRank"));
+  b.create<cf::AssertOp>(loc, predLTInputRank,
+                         b.getStringAttr("dim must be smaller than inputRank"));
 }
 
 // Hack to deal with the Torch list type arguments which is not supported end
@@ -147,8 +150,8 @@ static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
   Value rhsDimInt = rhsType.isIndex() ? castIndexToInt(b, loc, rhsDim) : rhsDim;
   Value contractingDimEqual = b.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::eq, lhsDimInt, rhsDimInt);
-  b.create<AssertOp>(loc, contractingDimEqual,
-                     b.getStringAttr("mismatching contracting dimension"));
+  b.create<cf::AssertOp>(loc, contractingDimEqual,
+                         b.getStringAttr("mismatching contracting dimension"));
 }
 
 static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
@@ -471,8 +474,8 @@ public:
         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
     Value groupEqual1 = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, groups, c1);
-    rewriter.create<AssertOp>(loc, groupEqual1,
-                              rewriter.getStringAttr("expect groups to be 1"));
+    rewriter.create<cf::AssertOp>(
+        loc, groupEqual1, rewriter.getStringAttr("expect groups to be 1"));
 
     // Pad the input tensor according to padding.
     SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
@@ -581,15 +584,16 @@ static void createLinalgPayloadCalculationForGatherOps(
   Value indexLTInputDim = b.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::slt, index,
       castIndexToInt(b, loc, getDimOp(b, loc, input, dim)));
-  b.create<AssertOp>(loc, indexLTInputDim,
-                     b.getStringAttr("index must be smaller than dim size"));
+  b.create<cf::AssertOp>(
+      loc, indexLTInputDim,
+      b.getStringAttr("index must be smaller than dim size"));
 
   // Assert index >= 0
   Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(index.getType()));
   Value indexGEThanZero =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, index, cst0);
-  b.create<AssertOp>(loc, indexGEThanZero,
-                     b.getStringAttr("index must be larger or equal to 0"));
+  b.create<cf::AssertOp>(loc, indexGEThanZero,
+                         b.getStringAttr("index must be larger or equal to 0"));
 
   Value extract = b.create<tensor::ExtractOp>(loc, input, indices);
   b.create<linalg::YieldOp>(loc, extract);
@@ -858,7 +862,7 @@ public:
     Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
     Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
-    rewriter.create<AssertOp>(
+    rewriter.create<cf::AssertOp>(
         loc, contractingDimEqual,
         rewriter.getStringAttr(
             "mismatching contracting dimension for torch.aten.mm"));
@@ -1196,7 +1200,7 @@ public:
                 loc, input, ValueRange{indI, indTarget});
             Value negate =
                 rewriter.create<arith::NegFOp>(loc, elementType, result);
-            Value selectFinal = rewriter.create<mlir::SelectOp>(
+            Value selectFinal = rewriter.create<arith::SelectOp>(
                 loc, cmpEq, zeroVal, negate);
             b.create<linalg::YieldOp>(loc, selectFinal);
           })
@@ -1302,7 +1306,7 @@ public:
                 rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
             Value negate =
                 rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
-            Value selectFinal = rewriter.create<mlir::SelectOp>(
+            Value selectFinal = rewriter.create<arith::SelectOp>(
                 loc, finalPredicate, negate, zeroVal);
             b.create<linalg::YieldOp>(loc, selectFinal);
           })
@@ -1376,7 +1380,7 @@ public:
     Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
     Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
-    rewriter.create<AssertOp>(
+    rewriter.create<cf::AssertOp>(
         loc, contractingDimEqual,
         rewriter.getStringAttr(
             "mismatching contracting dimension for aten.linear"));
@@ -1384,7 +1388,7 @@ public:
     // In the static-size-1 case, we will not emit this check at all.
     Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
-    rewriter.create<AssertOp>(
+    rewriter.create<cf::AssertOp>(
         loc, biasSizeCorrect,
         rewriter.getStringAttr("mismatching bias size for aten.linear"));
@@ -1500,31 +1504,31 @@ static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
   if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
     if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
       if (scalarFloat.getWidth() > dtypeFloat.getWidth())
-        return b.create<arith::TruncFOp>(loc, scalar, dtype);
+        return b.create<arith::TruncFOp>(loc, dtype, scalar);
       // Only scalarFloat width < dtypeFloat width can reach here.
-      return b.create<arith::ExtFOp>(loc, scalar, dtype);
+      return b.create<arith::ExtFOp>(loc, dtype, scalar);
     }
     assert(scalarType.isa<mlir::IntegerType>());
     if (scalarType.isSignlessInteger(1))
-      return b.create<arith::UIToFPOp>(loc, scalar, dtype);
+      return b.create<arith::UIToFPOp>(loc, dtype, scalar);
     // It's safe to use SIToFPOp because ui8/si8 are the only ones where
     // unsigned handling is needed, and we checked for that case above.
-    return b.create<arith::SIToFPOp>(loc, scalar, dtype);
+    return b.create<arith::SIToFPOp>(loc, dtype, scalar);
   }
   if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
     if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
-      return b.create<arith::FPToSIOp>(loc, scalar, dtype);
+      return b.create<arith::FPToSIOp>(loc, dtype, scalar);
     assert(scalarType.isa<mlir::IntegerType>());
     auto scalarInteger = scalarType.cast<mlir::IntegerType>();
     if (scalarInteger.getWidth() > dtypeInteger.getWidth())
-      return b.create<arith::TruncIOp>(loc, scalar, dtype);
+      return b.create<arith::TruncIOp>(loc, dtype, scalar);
     if (scalarType.isSignlessInteger(1))
-      return b.create<arith::ExtUIOp>(loc, scalar, dtype);
+      return b.create<arith::ExtUIOp>(loc, dtype, scalar);
     // Only scalarInteger width < dtypeInteger width can reach here.
    // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
    // unsigned handling is needed, and we checked for that case above.
-    return b.create<arith::ExtSIOp>(loc, scalar, dtype);
+    return b.create<arith::ExtSIOp>(loc, dtype, scalar);
   }
 
   llvm_unreachable("convertScalarToDtype should handle all the types");
@@ -1595,7 +1599,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
     Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          payloadArgs[0], constZero);
-    return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
+    return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
   }
   if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
     if (!lrelu.getType()
@@ -1610,8 +1614,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
     Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          payloadArgs[0], constZero);
-    Value positivePart = b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
-    Value negativePart = b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]);
+    Value positivePart =
+        b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
+    Value negativePart =
+        b.create<arith::SelectOp>(loc, pred, constZero, payloadArgs[0]);
     Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
     Value scaledNegativePart = b.create<arith::MulFOp>(loc, negativePart, scale);
     return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
@@ -1990,7 +1996,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                      .getElementType();
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
-    return b.create<SelectOp>(loc, payloadArgs[0], lhs, rhs);
+    return b.create<arith::SelectOp>(loc, payloadArgs[0], lhs, rhs);
   }
 
   if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
@@ -2019,7 +2025,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
     Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                          payloadArgs[0], payloadArgs[1]);
-    return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
+    return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
   }
   if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
     if (!maximum.getType()
@@ -2031,7 +2037,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
     Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          payloadArgs[0], payloadArgs[1]);
-    return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
+    return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
   }
   if (auto clamp = dyn_cast<AtenClampOp>(op)) {
     Type dtype = converter->convertType(clamp.getType())
@@ -2054,13 +2060,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
       auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                           result, minPromoted);
-      result = b.create<SelectOp>(loc, pred, minPromoted, result);
+      result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
     }
     if (!max.getType().isa<Torch::NoneType>()) {
       auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
       auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                           result, maxPromoted);
-      result = b.create<SelectOp>(loc, pred, maxPromoted, result);
+      result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
     }
     return result;
   }
@@ -2126,7 +2132,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
     auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE,
                                         payloadArgs[0], zero);
-    b.create<AssertOp>(
+    b.create<cf::AssertOp>(
         loc, pred, b.getStringAttr("unimplemented: tensor with zero element"));
 
     auto one =
@@ -2152,7 +2158,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     else
       predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
                                           threshold);
-    return b.create<SelectOp>(loc, predicate, value, self);
+    return b.create<arith::SelectOp>(loc, predicate, value, self);
   }
   if (auto thresholdBackward = dyn_cast<AtenThresholdBackwardOp>(op)) {
     // The approach used here is as follows:
@@ -2174,7 +2180,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     else
       predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
                                           threshold);
-    return b.create<SelectOp>(loc, predicate, constantZero, grad);
+    return b.create<arith::SelectOp>(loc, predicate, constantZero, grad);
   }
 
   op->emitError("unimplemented lowering in "
@@ -2347,9 +2353,9 @@ public:
         if (inElementType.isa<mlir::FloatType>())
           predicate = rewriter.create<arith::CmpFOp>(
               nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
-        auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
-                                                         newValue, oldValue);
-        auto resultIndex = rewriter.create<mlir::SelectOp>(
+        auto resultMax = rewriter.create<arith::SelectOp>(
+            nestedLoc, predicate, newValue, oldValue);
+        auto resultIndex = rewriter.create<arith::SelectOp>(
             nestedLoc, predicate, newIndex, oldIndex);
         nestedBuilder.create<linalg::YieldOp>(
             nestedLoc, ValueRange({resultMax, resultIndex}));
@@ -2487,8 +2493,8 @@ struct ConvertElementwiseOp : ConversionPattern {
         auto equalToRunning = rewriter.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::eq, resultShape[resultDim],
             currentDimSize);
-        rewriter.create<AssertOp>(loc, equalToRunning,
-                                  "mismatched size for broadcast");
+        rewriter.create<cf::AssertOp>(loc, equalToRunning,
+                                      "mismatched size for broadcast");
       }
       indexingMaps.push_back(AffineMap::get(
           /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, getContext()));
@@ -2703,7 +2709,7 @@ public:
         loc, IntegerAttr::get(rewriter.getIntegerType(1), 0));
     Value ceilModeFalse = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, ceilMode, falseValue);
-    rewriter.create<AssertOp>(
+    rewriter.create<cf::AssertOp>(
         loc, ceilModeFalse,
         rewriter.getStringAttr("only ceil_mode false is supported"));
@@ -3155,7 +3161,7 @@ public:
       Value dimSize = getDimOp(rewriter, loc, input, i);
      Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ne, dimSize, one);
-      rewriter.create<AssertOp>(
+      rewriter.create<cf::AssertOp>(
          loc, dimSizeNotOne,
          rewriter.getStringAttr(
              "unimplemented: size 1 dynamic dimension is not supported"));
@@ -3489,12 +3495,12 @@ public:
         loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
     Value predDimSltZero = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0);
-    Value startOrEndAtLeastZero = rewriter.create<SelectOp>(
+    Value startOrEndAtLeastZero = rewriter.create<arith::SelectOp>(
         loc, predDimSltZero, cst0, startOrEndToPositive);
     // startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
     Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
-    Value startOrEndBoundedByDimSize = rewriter.create<SelectOp>(
+    Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
         loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
 
     return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
@@ -3509,7 +3515,7 @@ public:
     // end >= start ? end : start
     Value endSgeStart = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sge, end, start);
-    end = rewriter.create<SelectOp>(loc, endSgeStart, end, start);
+    end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
 
     int64_t step;
     if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
@@ -3851,7 +3857,7 @@ public:
       if (i < diff) {
         Value isValid = rewriter.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::sge, shapeValue, zero);
-        rewriter.create<AssertOp>(
+        rewriter.create<cf::AssertOp>(
            loc, isValid,
            rewriter.getStringAttr(
                "negative values not allowed in new dimensions"));
@@ -3864,7 +3870,7 @@ public:
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
      Value isNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, shapeValue, zero);
-      Value select = rewriter.create<SelectOp>(
+      Value select = rewriter.create<arith::SelectOp>(
          loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
      outShape.push_back(select);
      outExpr.push_back(mlir::getAffineConstantExpr(0, context));
@@ -3878,7 +3884,7 @@ public:
          loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim),
          shapeValue);
      Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
-      rewriter.create<AssertOp>(
+      rewriter.create<cf::AssertOp>(
          loc, isValid,
          rewriter.getStringAttr(
              "only broadcasting singleton dimensions supported"));
@@ -4490,6 +4496,7 @@ public:
     registry.insert<StandardOpsDialect>();
     registry.insert<tensor::TensorDialect>();
     registry.insert<arith::ArithmeticDialect>();
+    registry.insert<cf::ControlFlowDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
@@ -4497,8 +4504,8 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
-                           math::MathDialect, tensor::TensorDialect,
-                           arith::ArithmeticDialect>();
+                           cf::ControlFlowDialect, math::MathDialect,
+                           tensor::TensorDialect, arith::ArithmeticDialect>();
     target.addLegalOp<GetNextSeedOp>();
 
     TypeConverter typeConverter;

View File

@@ -10,6 +10,7 @@
 #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
 
 #include "../PassDetail.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
@@ -52,8 +53,8 @@ public:
   LogicalResult
   matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<AssertOp>(op, adaptor.condition(),
-                                          adaptor.message());
+    rewriter.replaceOpWithNewOp<cf::AssertOp>(op, adaptor.condition(),
+                                              adaptor.message());
     return success();
   }
 };
@@ -160,6 +161,7 @@ public:
     registry.insert<StandardOpsDialect>();
     registry.insert<arith::ArithmeticDialect>();
     registry.insert<tensor::TensorDialect>();
+    registry.insert<cf::ControlFlowDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
@@ -167,7 +169,8 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<Torch::TorchDialect, StandardOpsDialect,
-                           arith::ArithmeticDialect, tensor::TensorDialect>();
+                           arith::ArithmeticDialect, tensor::TensorDialect,
+                           cf::ControlFlowDialect>();
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });

View File

@@ -290,7 +290,7 @@ PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
 // PrimIfOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
+ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
   // Create the regions.
   result.regions.reserve(2);
   Region *thenRegion = result.addRegion();
@@ -319,14 +319,14 @@ static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-static void print(OpAsmPrinter &p, PrimIfOp op) {
-  p << " " << op.condition();
-  p << " -> (" << op.getResultTypes() << ") ";
-  p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false);
+void PrimIfOp::print(OpAsmPrinter &p) {
+  p << " " << condition();
+  p << " -> (" << getResultTypes() << ") ";
+  p.printRegion(thenRegion(), /*printEntryBlockArgs=*/false);
   p << " else ";
-  p.printRegion(op.elseRegion(), /*printEntryBlockArgs=*/false);
-  p.printOptionalAttrDict(op->getAttrs());
+  p.printRegion(elseRegion(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict((*this)->getAttrs());
 }
 
 void PrimIfOp::getSuccessorRegions(Optional<unsigned> index,
@@ -938,8 +938,7 @@ void ConstantDeviceOp::getAsmResultNames(
 // ConstantIntOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseConstantIntOp(OpAsmParser &parser,
-                                      OperationState &result) {
+ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) {
   Builder builder(result.getContext());
   result.addTypes(builder.getType<Torch::IntType>());
   if (parser.parseOptionalAttrDict(result.attributes))
@@ -951,10 +950,10 @@ static ParseResult parseConstantIntOp(OpAsmParser &parser,
   return success();
 }
 
-static void print(OpAsmPrinter &p, Torch::ConstantIntOp op) {
+void ConstantIntOp::print(OpAsmPrinter &p) {
   p << " ";
-  p << op.value().getSExtValue();
-  p.printOptionalAttrDict(op->getAttrs(), {"value"});
+  p << value().getSExtValue();
+  p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
 }
 
 OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {

View File

@@ -9,8 +9,10 @@
 
 #include "PassDetail.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -68,9 +70,8 @@ class VerifyLinalgOnTensorsBackendContractPass
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
     target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
     target.addDynamicallyLegalDialect<AffineDialect>(opHasLegalTypes);
+    target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
 
-    // AssertOp is used to terminate the program for error guards.
-    target.addLegalOp<AssertOp>();
     // ConstantOp is used for tensors and for scalars.
     target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);

View File

@@ -16,12 +16,15 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Approximation.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
 #include "torch-mlir/RefBackend/Passes.h"
 #include <numeric>
 #include <set>
@@ -136,7 +139,7 @@ static LogicalResult mungeFunction(
     if (!isArgMemRefTypeValid(type))
       return emitError(arg.getLoc(),
                        "argument must be a memref of f32, f64, i32, i64, i1");
-    auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
+    auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
     arg.replaceAllUsesExcept(cast, cast);
     arg.setType(getAbiTypeForMemRef(type));
     newArgTypes.push_back(arg.getType());
@@ -159,7 +162,7 @@ static LogicalResult mungeFunction(
         // Cast to unranked memref type before sending it as a function
         // argument.
         retVal = b.create<memref::CastOp>(
-            op.getLoc(), retVal, getAbiTypeForMemRef(types[en.index()]));
+            op.getLoc(), getAbiTypeForMemRef(types[en.index()]), retVal);
       }
       retTypes.push_back(retType);
       retVals.push_back(retVal);
@@ -354,3 +357,64 @@ std::unique_ptr<OperationPass<FuncOp>>
 mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
   return std::make_unique<ExpandOpsForLLVM>();
 }
+
+//===----------------------------------------------------------------------===//
+// MungeMemrefCopy
+//===----------------------------------------------------------------------===//
+
+Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
+                              Value to) {
+  auto memrefTypeFrom = from.getType().cast<MemRefType>();
+  auto memrefTypeTo = to.getType().cast<MemRefType>();
+  (void)memrefTypeFrom;
+  assert(memrefTypeFrom && memrefTypeTo &&
+         memrefTypeFrom.getRank() == memrefTypeTo.getRank());
+  AffineMap id =
+      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
+  SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
+                                       getParallelIteratorTypeName());
+  return b.create<linalg::GenericOp>(
+      loc,
+      /*inputs=*/from,
+      /*outputs=*/to,
+      /*indexingMaps=*/llvm::makeArrayRef({id, id}),
+      /*iteratorTypes=*/iteratorTypes,
+      [](OpBuilder &b, Location loc, ValueRange args) {
+        b.create<linalg::YieldOp>(loc, args.front());
+      });
+}
+
+namespace {
+class MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> {
+  using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *linalgCopy = createLinalgCopyOp(
+        rewriter, copyOp.getLoc(), copyOp.source(), copyOp.target());
+    rewriter.replaceOp(copyOp, linalgCopy->getResults());
+    return success();
+  }
+};
+
+class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<MemrefCopyOpToLinalg>(context);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::torch::RefBackend::createMungeMemrefCopyPass() {
+  return std::make_unique<MungeMemrefCopy>();
+}
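
The new pattern replaces every memref.copy with a rank-polymorphic linalg.generic that simply yields its input element. For a rank-1 copy, the emitted IR looks roughly like this (a hand-written sketch of the pattern's output with illustrative value names, not IR captured from this patch):

// Given:  memref.copy %src, %dst : memref<?xf32> to memref<?xf32>
// The pattern rewrites it to an identity-map, all-parallel generic:
#map = affine_map<(d0) -> (d0)>
linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
    ins(%src : memref<?xf32>) outs(%dst : memref<?xf32>) {
^bb0(%in: f32, %out: f32):
  linalg.yield %in : f32
}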

View File

@@ -165,12 +165,12 @@ class RefBackendInvoker:
 
 LOWERING_PIPELINE = ",".join([
     # Bufferize.
-    "tensor-constant-bufferize",
     "builtin.func(scf-bufferize)",
     "builtin.func(linalg-bufferize)",
-    "builtin.func(std-bufferize)",
-    "builtin.func(tensor-bufferize)",
+    "builtin.func(refback-munge-memref-copy)",
     "func-bufferize",
+    "arith-bufferize",
+    "builtin.func(tensor-bufferize)",
     "builtin.func(finalizing-bufferize)",
     # Munge to make it ExecutionEngine compatible.
     # Specifically, we rewrite calling convention boundaries to be in terms
# Specifically, we rewrite calling convention boundaries to be in terms
@@ -185,12 +185,15 @@ LOWERING_PIPELINE = ",".join([
     # Lower to LLVM
     "builtin.func(convert-linalg-to-loops)",
     "builtin.func(lower-affine)",
-    "builtin.func(convert-scf-to-std)",
+    "convert-scf-to-cf",
     "builtin.func(refback-expand-ops-for-llvm)",
     "builtin.func(arith-expand)",
     "builtin.func(convert-math-to-llvm)",
     "convert-linalg-to-llvm",
     "convert-memref-to-llvm",
+    "builtin.func(convert-arith-to-llvm)",
     "convert-std-to-llvm",
+    "convert-cf-to-llvm",
     "reconcile-unrealized-casts",
 ])

View File

@@ -42,7 +42,7 @@ func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
 // expected-error @+1 {{unimplemented}}
 func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   %ctrue = arith.constant true
-  cond_br %ctrue, ^bb1, ^bb2
+  cf.cond_br %ctrue, ^bb1, ^bb2
 ^bb1:
   return %arg0 : tensor<*xf32>
 ^bb2:

View File

@@ -343,7 +343,7 @@ builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> (!torch.vtensor, !torch.v
 // CHECK: return %[[CAST]] : !torch.vtensor
 builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
   %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
-  br ^bb1(%cast: !torch.vtensor)
+  cf.br ^bb1(%cast: !torch.vtensor)
 ^bb1(%arg1: !torch.vtensor):
   %1 = torch.aten.tanh %arg1 : !torch.vtensor -> !torch.vtensor
   return %1 : !torch.vtensor

View File

@@ -12,11 +12,11 @@ func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
 
 // CHECK-LABEL:   func @block_arguments(
 // CHECK-SAME:        %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK:           br ^bb1(%[[ARG]] : tensor<f32>)
+// CHECK:           cf.br ^bb1(%[[ARG]] : tensor<f32>)
 // CHECK:         ^bb1(%[[BBARG:.*]]: tensor<f32>):
 // CHECK:           return %[[BBARG]] : tensor<f32>
 func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
-  br ^bb1(%arg0: !torch.vtensor<[],f32>)
+  cf.br ^bb1(%arg0: !torch.vtensor<[],f32>)
 ^bb1(%bbarg: !torch.vtensor<[],f32>):
   return %bbarg : !torch.vtensor<[],f32>
 }
@@ -55,7 +55,7 @@ func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
 // update all terminators and issue an error if that is not possible.
 func @unable_to_update_terminator(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
   %0 = arith.constant true
-  cond_br %0, ^bb1(%arg0: !torch.vtensor<[],f32>), ^bb2(%arg0: !torch.vtensor<[],f32>)
+  cf.cond_br %0, ^bb1(%arg0: !torch.vtensor<[],f32>), ^bb2(%arg0: !torch.vtensor<[],f32>)
 ^bb1(%bbarg0: !torch.vtensor<[],f32>):
   // expected-error @+1 {{failed to legalize operation 'test.terminator'}}
   "test.terminator"() : () -> ()

View File

@@ -10,7 +10,7 @@ func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %2 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
   %3 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
   %4 = arith.cmpi eq, %1, %2 : index
-  assert %4, "mismatching contracting dimension for aten.mm"
+  cf.assert %4, "mismatching contracting dimension for aten.mm"
   %5 = linalg.init_tensor [%0, %3] : tensor<?x?xf32>
   %6 = linalg.fill(%cst, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
   %7 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%6 : tensor<?x?xf32>) -> tensor<?x?xf32>