mirror of https://github.com/llvm/torch-mlir
LLVM bump
Major changes: opTrait changed to Trait, selectOp moved to arith dialect assertOp moved to cf dialectpull/608/head
parent
442ff4605c
commit
f8cb32faf0
|
@ -20,11 +20,11 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
|
|||
// Base class.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class TMTensor_PureOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
class TMTensor_PureOp<string mnemonic, list<Trait> traits = []> :
|
||||
Op<TMTensor_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
class TMTensor_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
class TMTensor_Op<string mnemonic, list<Trait> traits = []> :
|
||||
TMTensor_PureOp<mnemonic, !listconcat(traits,
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 84fe34a0b7fdd7bbf179981d1583693d5d5ec68b
|
||||
Subproject commit bfc6fbfb65f6d490e6a1ba4eb6c734d2b4494dd1
|
|
@ -38,7 +38,7 @@ def Torch_Dialect : Dialect {
|
|||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
class TorchOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
|
||||
class TorchOpTrait<string name> : NativeOpTrait<""> {
|
||||
let trait = name;
|
||||
let cppNamespace = "::mlir::torch::Torch::OpTrait";
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
|
|||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
class Torch_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<Torch_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
|
@ -503,8 +503,8 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
|
|||
let arguments = (ins Torch_BoolType:$condition);
|
||||
let results = (outs Variadic<AnyTorchType>:$results);
|
||||
let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion);
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parsePrimIfOp(parser, result); }];
|
||||
// Indicate that the operation has a custom parser and printer method.
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
@ -594,8 +594,8 @@ def Torch_ConstantIntOp : Torch_Op<"constant.int",
|
|||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parseConstantIntOp(parser, result); }];
|
||||
// Indicate that the operation has a custom parser and printer method.
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type
|
|||
}];
|
||||
|
||||
let parser = [{
|
||||
if (parser.parseLess())
|
||||
if ($_parser.parseLess())
|
||||
return Type();
|
||||
Type containedType;
|
||||
if ($_parser.parseType(containedType))
|
||||
|
@ -39,7 +39,6 @@ class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type
|
|||
return Type();
|
||||
return get($_ctxt, containedType);
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
|
||||
return Base::get(containedType.getContext(), containedType);
|
||||
|
@ -61,9 +60,9 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
|
|||
}];
|
||||
|
||||
let parser = [{
|
||||
if (parser.parseLess())
|
||||
if ($_parser.parseLess())
|
||||
return Type();
|
||||
std::string className;
|
||||
std::string className;
|
||||
if ($_parser.parseOptionalString(&className))
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
|
@ -349,7 +348,7 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
|
|||
}];
|
||||
|
||||
let parser = [{
|
||||
if (parser.parseLess())
|
||||
if ($_parser.parseLess())
|
||||
return Type();
|
||||
Type keyType;
|
||||
if ($_parser.parseType(keyType))
|
||||
|
@ -363,7 +362,6 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
|
|||
return Type();
|
||||
return get($_ctxt, keyType, valueType);
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
|
||||
"::mlir::Type":$valueType), [{
|
||||
|
|
|
@ -17,7 +17,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
|||
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
|
||||
include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td"
|
||||
|
||||
class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
class TorchConversion_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<TorchConversion_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
|
|||
std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createMungeMemrefCopyPass();
|
||||
} // namespace RefBackend
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -29,4 +29,10 @@ def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "FuncOp"> {
|
|||
let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
|
||||
}
|
||||
|
||||
def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "FuncOp"> {
|
||||
let summary = "Munge memref.copy to linalg.copy";
|
||||
let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
|
||||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
#endif // TORCHMLIR_REFBACKEND_PASSES
|
||||
|
|
|
@ -11,8 +11,10 @@
|
|||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
|
@ -78,7 +80,8 @@ static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
|
|||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
|
||||
Value predDimGEZero =
|
||||
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
|
||||
Value dimInt = b.create<SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
|
||||
Value dimInt =
|
||||
b.create<arith::SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
|
||||
return dimInt;
|
||||
}
|
||||
|
||||
|
@ -91,12 +94,12 @@ static void assertIsValidDim(OpBuilder &b, Location loc, Value dim,
|
|||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
|
||||
Value predGEZero =
|
||||
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
|
||||
b.create<AssertOp>(loc, predGEZero,
|
||||
b.getStringAttr("dim must be greater or equal to zero"));
|
||||
b.create<cf::AssertOp>(
|
||||
loc, predGEZero, b.getStringAttr("dim must be greater or equal to zero"));
|
||||
Value predLTInputRank =
|
||||
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank);
|
||||
b.create<AssertOp>(loc, predLTInputRank,
|
||||
b.getStringAttr("dim must be smaller than inputRank"));
|
||||
b.create<cf::AssertOp>(loc, predLTInputRank,
|
||||
b.getStringAttr("dim must be smaller than inputRank"));
|
||||
}
|
||||
|
||||
// Hack to deal with the Torch list type arguments which is not supported end
|
||||
|
@ -147,8 +150,8 @@ static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
|
|||
Value rhsDimInt = rhsType.isIndex() ? castIndexToInt(b, loc, rhsDim) : rhsDim;
|
||||
Value contractingDimEqual = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, lhsDimInt, rhsDimInt);
|
||||
b.create<AssertOp>(loc, contractingDimEqual,
|
||||
b.getStringAttr("mismatching contracting dimension"));
|
||||
b.create<cf::AssertOp>(loc, contractingDimEqual,
|
||||
b.getStringAttr("mismatching contracting dimension"));
|
||||
}
|
||||
|
||||
static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
|
||||
|
@ -471,8 +474,8 @@ public:
|
|||
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
|
||||
Value groupEqual1 = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, groups, c1);
|
||||
rewriter.create<AssertOp>(loc, groupEqual1,
|
||||
rewriter.getStringAttr("expect groups to be 1"));
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, groupEqual1, rewriter.getStringAttr("expect groups to be 1"));
|
||||
|
||||
// Pad the input tensor according to padding.
|
||||
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
|
||||
|
@ -581,15 +584,16 @@ static void createLinalgPayloadCalculationForGatherOps(
|
|||
Value indexLTInputDim = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, index,
|
||||
castIndexToInt(b, loc, getDimOp(b, loc, input, dim)));
|
||||
b.create<AssertOp>(loc, indexLTInputDim,
|
||||
b.getStringAttr("index must be smaller than dim size"));
|
||||
b.create<cf::AssertOp>(
|
||||
loc, indexLTInputDim,
|
||||
b.getStringAttr("index must be smaller than dim size"));
|
||||
|
||||
// Assert index >= 0
|
||||
Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(index.getType()));
|
||||
Value indexGEThanZero =
|
||||
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, index, cst0);
|
||||
b.create<AssertOp>(loc, indexGEThanZero,
|
||||
b.getStringAttr("index must be larger or equal to 0"));
|
||||
b.create<cf::AssertOp>(loc, indexGEThanZero,
|
||||
b.getStringAttr("index must be larger or equal to 0"));
|
||||
|
||||
Value extract = b.create<tensor::ExtractOp>(loc, input, indices);
|
||||
b.create<linalg::YieldOp>(loc, extract);
|
||||
|
@ -858,7 +862,7 @@ public:
|
|||
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
|
||||
Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
|
||||
rewriter.create<AssertOp>(
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, contractingDimEqual,
|
||||
rewriter.getStringAttr(
|
||||
"mismatching contracting dimension for torch.aten.mm"));
|
||||
|
@ -1196,7 +1200,7 @@ public:
|
|||
loc, input, ValueRange{indI, indTarget});
|
||||
Value negate =
|
||||
rewriter.create<arith::NegFOp>(loc, elementType, result);
|
||||
Value selectFinal = rewriter.create<mlir::SelectOp>(
|
||||
Value selectFinal = rewriter.create<arith::SelectOp>(
|
||||
loc, cmpEq, zeroVal, negate);
|
||||
b.create<linalg::YieldOp>(loc, selectFinal);
|
||||
})
|
||||
|
@ -1302,7 +1306,7 @@ public:
|
|||
rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
|
||||
Value negate =
|
||||
rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
|
||||
Value selectFinal = rewriter.create<mlir::SelectOp>(
|
||||
Value selectFinal = rewriter.create<arith::SelectOp>(
|
||||
loc, finalPredicate, negate, zeroVal);
|
||||
b.create<linalg::YieldOp>(loc, selectFinal);
|
||||
})
|
||||
|
@ -1376,7 +1380,7 @@ public:
|
|||
Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
|
||||
Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
|
||||
rewriter.create<AssertOp>(
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, contractingDimEqual,
|
||||
rewriter.getStringAttr(
|
||||
"mismatching contracting dimension for aten.linear"));
|
||||
|
@ -1384,7 +1388,7 @@ public:
|
|||
// In the static-size-1 case, we will not emit this check at all.
|
||||
Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
|
||||
rewriter.create<AssertOp>(
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, biasSizeCorrect,
|
||||
rewriter.getStringAttr("mismatching bias size for aten.linear"));
|
||||
|
||||
|
@ -1500,31 +1504,31 @@ static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
|||
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
|
||||
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
|
||||
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
|
||||
return b.create<arith::TruncFOp>(loc, scalar, dtype);
|
||||
return b.create<arith::TruncFOp>(loc, dtype, scalar);
|
||||
// Only scalarFloat width < dtypeFloat width can reach here.
|
||||
return b.create<arith::ExtFOp>(loc, scalar, dtype);
|
||||
return b.create<arith::ExtFOp>(loc, dtype, scalar);
|
||||
}
|
||||
assert(scalarType.isa<mlir::IntegerType>());
|
||||
if (scalarType.isSignlessInteger(1))
|
||||
return b.create<arith::UIToFPOp>(loc, scalar, dtype);
|
||||
return b.create<arith::UIToFPOp>(loc, dtype, scalar);
|
||||
// It's safe to use SIToFPOp because ui8/si8 are the only ones where
|
||||
// unsigned handling is needed, and we checked for that case above.
|
||||
return b.create<arith::SIToFPOp>(loc, scalar, dtype);
|
||||
return b.create<arith::SIToFPOp>(loc, dtype, scalar);
|
||||
}
|
||||
|
||||
if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
|
||||
return b.create<arith::FPToSIOp>(loc, scalar, dtype);
|
||||
return b.create<arith::FPToSIOp>(loc, dtype, scalar);
|
||||
assert(scalarType.isa<mlir::IntegerType>());
|
||||
auto scalarInteger = scalarType.cast<mlir::IntegerType>();
|
||||
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
|
||||
return b.create<arith::TruncIOp>(loc, scalar, dtype);
|
||||
return b.create<arith::TruncIOp>(loc, dtype, scalar);
|
||||
if (scalarType.isSignlessInteger(1))
|
||||
return b.create<arith::ExtUIOp>(loc, scalar, dtype);
|
||||
return b.create<arith::ExtUIOp>(loc, dtype, scalar);
|
||||
// Only scalarInteger width < dtypeInteger width can reach here.
|
||||
// It's safe to use ExtSIOp here because ui8/si8 are the only ones where
|
||||
// unsigned handling is needed, and we checked for that case above.
|
||||
return b.create<arith::ExtSIOp>(loc, scalar, dtype);
|
||||
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
|
||||
}
|
||||
|
||||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||
|
@ -1595,7 +1599,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
payloadArgs[0], constZero);
|
||||
return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
|
||||
return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
|
||||
}
|
||||
if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
|
||||
if (!lrelu.getType()
|
||||
|
@ -1610,8 +1614,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
payloadArgs[0], constZero);
|
||||
Value positivePart = b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
|
||||
Value negativePart = b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]);
|
||||
Value positivePart =
|
||||
b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
|
||||
Value negativePart =
|
||||
b.create<arith::SelectOp>(loc, pred, constZero, payloadArgs[0]);
|
||||
Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
|
||||
Value scaledNegativePart = b.create<arith::MulFOp>(loc, negativePart, scale);
|
||||
return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
|
||||
|
@ -1990,7 +1996,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
|
||||
return b.create<SelectOp>(loc, payloadArgs[0], lhs, rhs);
|
||||
return b.create<arith::SelectOp>(loc, payloadArgs[0], lhs, rhs);
|
||||
}
|
||||
|
||||
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
|
||||
|
@ -2019,7 +2025,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
|
||||
return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
|
||||
if (!maximum.getType()
|
||||
|
@ -2031,7 +2037,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
|
||||
return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
|
||||
Type dtype = converter->convertType(clamp.getType())
|
||||
|
@ -2054,13 +2060,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
|
||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
result, minPromoted);
|
||||
result = b.create<SelectOp>(loc, pred, minPromoted, result);
|
||||
result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
|
||||
}
|
||||
if (!max.getType().isa<Torch::NoneType>()) {
|
||||
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
|
||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
result, maxPromoted);
|
||||
result = b.create<SelectOp>(loc, pred, maxPromoted, result);
|
||||
result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -2126,7 +2132,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
|
||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE,
|
||||
payloadArgs[0], zero);
|
||||
b.create<AssertOp>(
|
||||
b.create<cf::AssertOp>(
|
||||
loc, pred, b.getStringAttr("unimplemented: tensor with zero element"));
|
||||
|
||||
auto one =
|
||||
|
@ -2152,7 +2158,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
else
|
||||
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
|
||||
threshold);
|
||||
return b.create<SelectOp>(loc, predicate, value, self);
|
||||
return b.create<arith::SelectOp>(loc, predicate, value, self);
|
||||
}
|
||||
if (auto thresholdBackward = dyn_cast<AtenThresholdBackwardOp>(op)) {
|
||||
// The approach used here is as follows:
|
||||
|
@ -2174,7 +2180,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
else
|
||||
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
|
||||
threshold);
|
||||
return b.create<SelectOp>(loc, predicate, constantZero, grad);
|
||||
return b.create<arith::SelectOp>(loc, predicate, constantZero, grad);
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in "
|
||||
|
@ -2347,9 +2353,9 @@ public:
|
|||
if (inElementType.isa<mlir::FloatType>())
|
||||
predicate = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||
auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
|
||||
newValue, oldValue);
|
||||
auto resultIndex = rewriter.create<mlir::SelectOp>(
|
||||
auto resultMax = rewriter.create<arith::SelectOp>(
|
||||
nestedLoc, predicate, newValue, oldValue);
|
||||
auto resultIndex = rewriter.create<arith::SelectOp>(
|
||||
nestedLoc, predicate, newIndex, oldIndex);
|
||||
nestedBuilder.create<linalg::YieldOp>(
|
||||
nestedLoc, ValueRange({resultMax, resultIndex}));
|
||||
|
@ -2487,8 +2493,8 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
auto equalToRunning = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, resultShape[resultDim],
|
||||
currentDimSize);
|
||||
rewriter.create<AssertOp>(loc, equalToRunning,
|
||||
"mismatched size for broadcast");
|
||||
rewriter.create<cf::AssertOp>(loc, equalToRunning,
|
||||
"mismatched size for broadcast");
|
||||
}
|
||||
indexingMaps.push_back(AffineMap::get(
|
||||
/*dimCount=*/resultRank, /*symbolCount=*/0, exprs, getContext()));
|
||||
|
@ -2703,7 +2709,7 @@ public:
|
|||
loc, IntegerAttr::get(rewriter.getIntegerType(1), 0));
|
||||
Value ceilModeFalse = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, ceilMode, falseValue);
|
||||
rewriter.create<AssertOp>(
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, ceilModeFalse,
|
||||
rewriter.getStringAttr("only ceil_mode false is supported"));
|
||||
|
||||
|
@ -3155,7 +3161,7 @@ public:
|
|||
Value dimSize = getDimOp(rewriter, loc, input, i);
|
||||
Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ne, dimSize, one);
|
||||
rewriter.create<AssertOp>(
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, dimSizeNotOne,
|
||||
rewriter.getStringAttr(
|
||||
"unimplemented: size 1 dynamic dimension is not supported"));
|
||||
|
@ -3489,12 +3495,12 @@ public:
|
|||
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
||||
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0);
|
||||
Value startOrEndAtLeastZero = rewriter.create<SelectOp>(
|
||||
Value startOrEndAtLeastZero = rewriter.create<arith::SelectOp>(
|
||||
loc, predDimSltZero, cst0, startOrEndToPositive);
|
||||
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
|
||||
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
|
||||
Value startOrEndBoundedByDimSize = rewriter.create<SelectOp>(
|
||||
Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
|
||||
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
|
||||
|
||||
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
|
||||
|
@ -3509,7 +3515,7 @@ public:
|
|||
// end >= start ? end : start
|
||||
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, end, start);
|
||||
end = rewriter.create<SelectOp>(loc, endSgeStart, end, start);
|
||||
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
|
||||
|
||||
int64_t step;
|
||||
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
|
||||
|
@ -3851,7 +3857,7 @@ public:
|
|||
if (i < diff) {
|
||||
Value isValid = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, shapeValue, zero);
|
||||
rewriter.create<AssertOp>(
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"negative values not allowed in new dimensions"));
|
||||
|
@ -3864,7 +3870,7 @@ public:
|
|||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value select = rewriter.create<SelectOp>(
|
||||
Value select = rewriter.create<arith::SelectOp>(
|
||||
loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
|
||||
outShape.push_back(select);
|
||||
outExpr.push_back(mlir::getAffineConstantExpr(0, context));
|
||||
|
@ -3878,7 +3884,7 @@ public:
|
|||
loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim),
|
||||
shapeValue);
|
||||
Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
|
||||
rewriter.create<AssertOp>(
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"only broadcasting singleton dimensions supported"));
|
||||
|
@ -4490,6 +4496,7 @@ public:
|
|||
registry.insert<StandardOpsDialect>();
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<arith::ArithmeticDialect>();
|
||||
registry.insert<cf::ControlFlowDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
||||
|
@ -4497,8 +4504,8 @@ public:
|
|||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
math::MathDialect, tensor::TensorDialect,
|
||||
arith::ArithmeticDialect>();
|
||||
cf::ControlFlowDialect, math::MathDialect,
|
||||
tensor::TensorDialect, arith::ArithmeticDialect>();
|
||||
target.addLegalOp<GetNextSeedOp>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
|
@ -52,8 +53,8 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<AssertOp>(op, adaptor.condition(),
|
||||
adaptor.message());
|
||||
rewriter.replaceOpWithNewOp<cf::AssertOp>(op, adaptor.condition(),
|
||||
adaptor.message());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -160,6 +161,7 @@ public:
|
|||
registry.insert<StandardOpsDialect>();
|
||||
registry.insert<arith::ArithmeticDialect>();
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<cf::ControlFlowDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
||||
|
@ -167,7 +169,8 @@ public:
|
|||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<Torch::TorchDialect, StandardOpsDialect,
|
||||
arith::ArithmeticDialect, tensor::TensorDialect>();
|
||||
arith::ArithmeticDialect, tensor::TensorDialect,
|
||||
cf::ControlFlowDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
|
|
|
@ -290,7 +290,7 @@ PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
|
|||
// PrimIfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
|
||||
ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
// Create the regions.
|
||||
result.regions.reserve(2);
|
||||
Region *thenRegion = result.addRegion();
|
||||
|
@ -319,14 +319,14 @@ static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, PrimIfOp op) {
|
||||
p << " " << op.condition();
|
||||
p << " -> (" << op.getResultTypes() << ") ";
|
||||
p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false);
|
||||
void PrimIfOp::print(OpAsmPrinter &p) {
|
||||
p << " " << condition();
|
||||
p << " -> (" << getResultTypes() << ") ";
|
||||
p.printRegion(thenRegion(), /*printEntryBlockArgs=*/false);
|
||||
p << " else ";
|
||||
p.printRegion(op.elseRegion(), /*printEntryBlockArgs=*/false);
|
||||
p.printRegion(elseRegion(), /*printEntryBlockArgs=*/false);
|
||||
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
p.printOptionalAttrDict((*this)->getAttrs());
|
||||
}
|
||||
|
||||
void PrimIfOp::getSuccessorRegions(Optional<unsigned> index,
|
||||
|
@ -938,8 +938,7 @@ void ConstantDeviceOp::getAsmResultNames(
|
|||
// ConstantIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseConstantIntOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
Builder builder(result.getContext());
|
||||
result.addTypes(builder.getType<Torch::IntType>());
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
|
@ -951,10 +950,10 @@ static ParseResult parseConstantIntOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, Torch::ConstantIntOp op) {
|
||||
void ConstantIntOp::print(OpAsmPrinter &p) {
|
||||
p << " ";
|
||||
p << op.value().getSExtValue();
|
||||
p.printOptionalAttrDict(op->getAttrs(), {"value"});
|
||||
p << value().getSExtValue();
|
||||
p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
|
||||
}
|
||||
|
||||
OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
|
|
@ -9,8 +9,10 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
@ -68,9 +70,8 @@ class VerifyLinalgOnTensorsBackendContractPass
|
|||
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<AffineDialect>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
|
||||
|
||||
// AssertOp is used to terminate the program for error guards.
|
||||
target.addLegalOp<AssertOp>();
|
||||
// ConstantOp is used for tensors and for scalars.
|
||||
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
|
||||
|
||||
|
|
|
@ -16,12 +16,15 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Approximation.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
#include "torch-mlir/RefBackend/Passes.h"
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
|
@ -136,7 +139,7 @@ static LogicalResult mungeFunction(
|
|||
if (!isArgMemRefTypeValid(type))
|
||||
return emitError(arg.getLoc(),
|
||||
"argument must be a memref of f32, f64, i32, i64, i1");
|
||||
auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
|
||||
auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
|
||||
arg.replaceAllUsesExcept(cast, cast);
|
||||
arg.setType(getAbiTypeForMemRef(type));
|
||||
newArgTypes.push_back(arg.getType());
|
||||
|
@ -159,7 +162,7 @@ static LogicalResult mungeFunction(
|
|||
// Cast to unranked memref type before sending it as a function
|
||||
// argument.
|
||||
retVal = b.create<memref::CastOp>(
|
||||
op.getLoc(), retVal, getAbiTypeForMemRef(types[en.index()]));
|
||||
op.getLoc(), getAbiTypeForMemRef(types[en.index()]), retVal);
|
||||
}
|
||||
retTypes.push_back(retType);
|
||||
retVals.push_back(retVal);
|
||||
|
@ -354,3 +357,64 @@ std::unique_ptr<OperationPass<FuncOp>>
|
|||
mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
|
||||
return std::make_unique<ExpandOpsForLLVM>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MungeMemrefCopy
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
|
||||
Value to) {
|
||||
auto memrefTypeFrom = from.getType().cast<MemRefType>();
|
||||
auto memrefTypeTo = to.getType().cast<MemRefType>();
|
||||
(void)memrefTypeFrom;
|
||||
assert(memrefTypeFrom && memrefTypeTo &&
|
||||
memrefTypeFrom.getRank() == memrefTypeTo.getRank());
|
||||
AffineMap id =
|
||||
AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
|
||||
SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
|
||||
getParallelIteratorTypeName());
|
||||
return b.create<linalg::GenericOp>(
|
||||
loc,
|
||||
/*inputs=*/from,
|
||||
/*outputs=*/to,
|
||||
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args.front());
|
||||
});
|
||||
}
|
||||
|
||||
namespace {
|
||||
class MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> {
|
||||
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::CopyOp copyOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *linalgCopy = createLinalgCopyOp(
|
||||
rewriter, copyOp.getLoc(), copyOp.source(), copyOp.target());
|
||||
rewriter.replaceOp(copyOp, linalgCopy->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<linalg::LinalgDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(&getContext());
|
||||
patterns.insert<MemrefCopyOpToLinalg>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::torch::RefBackend::createMungeMemrefCopyPass() {
|
||||
return std::make_unique<MungeMemrefCopy>();
|
||||
}
|
||||
|
|
|
@ -165,12 +165,12 @@ class RefBackendInvoker:
|
|||
|
||||
LOWERING_PIPELINE = ",".join([
|
||||
# Bufferize.
|
||||
"tensor-constant-bufferize",
|
||||
"builtin.func(scf-bufferize)",
|
||||
"builtin.func(linalg-bufferize)",
|
||||
"builtin.func(std-bufferize)",
|
||||
"builtin.func(tensor-bufferize)",
|
||||
"builtin.func(refback-munge-memref-copy)",
|
||||
"func-bufferize",
|
||||
"arith-bufferize",
|
||||
"builtin.func(tensor-bufferize)",
|
||||
"builtin.func(finalizing-bufferize)",
|
||||
# Munge to make it ExecutionEngine compatible.
|
||||
# Specifically, we rewrite calling convention boundaries to be in terms
|
||||
|
@ -185,12 +185,15 @@ LOWERING_PIPELINE = ",".join([
|
|||
# Lower to LLVM
|
||||
"builtin.func(convert-linalg-to-loops)",
|
||||
"builtin.func(lower-affine)",
|
||||
"builtin.func(convert-scf-to-std)",
|
||||
"convert-scf-to-cf",
|
||||
"builtin.func(refback-expand-ops-for-llvm)",
|
||||
"builtin.func(arith-expand)",
|
||||
"builtin.func(convert-math-to-llvm)",
|
||||
"convert-linalg-to-llvm",
|
||||
"convert-memref-to-llvm",
|
||||
"builtin.func(convert-arith-to-llvm)",
|
||||
"convert-std-to-llvm",
|
||||
"convert-cf-to-llvm",
|
||||
"reconcile-unrealized-casts",
|
||||
])
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
|||
// expected-error @+1 {{unimplemented}}
|
||||
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%ctrue = arith.constant true
|
||||
cond_br %ctrue, ^bb1, ^bb2
|
||||
cf.cond_br %ctrue, ^bb1, ^bb2
|
||||
^bb1:
|
||||
return %arg0 : tensor<*xf32>
|
||||
^bb2:
|
||||
|
|
|
@ -343,7 +343,7 @@ builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> (!torch.vtensor, !torch.v
|
|||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
|
||||
br ^bb1(%cast: !torch.vtensor)
|
||||
cf.br ^bb1(%cast: !torch.vtensor)
|
||||
^bb1(%arg1: !torch.vtensor):
|
||||
%1 = torch.aten.tanh %arg1 : !torch.vtensor -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
|
|
|
@ -12,11 +12,11 @@ func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
|||
|
||||
// CHECK-LABEL: func @block_arguments(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: br ^bb1(%[[ARG]] : tensor<f32>)
|
||||
// CHECK: cf.br ^bb1(%[[ARG]] : tensor<f32>)
|
||||
// CHECK: ^bb1(%[[BBARG:.*]]: tensor<f32>):
|
||||
// CHECK: return %[[BBARG]] : tensor<f32>
|
||||
func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
br ^bb1(%arg0: !torch.vtensor<[],f32>)
|
||||
cf.br ^bb1(%arg0: !torch.vtensor<[],f32>)
|
||||
^bb1(%bbarg: !torch.vtensor<[],f32>):
|
||||
return %bbarg : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
|
|||
// update all terminators and issue an error if that is not possible.
|
||||
func @unable_to_update_terminator(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
%0 = arith.constant true
|
||||
cond_br %0, ^bb1(%arg0: !torch.vtensor<[],f32>), ^bb2(%arg0: !torch.vtensor<[],f32>)
|
||||
cf.cond_br %0, ^bb1(%arg0: !torch.vtensor<[],f32>), ^bb2(%arg0: !torch.vtensor<[],f32>)
|
||||
^bb1(%bbarg0: !torch.vtensor<[],f32>):
|
||||
// expected-error @+1 {{failed to legalize operation 'test.terminator'}}
|
||||
"test.terminator"() : () -> ()
|
||||
|
|
|
@ -10,7 +10,7 @@ func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
|||
%2 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
|
||||
%3 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
|
||||
%4 = arith.cmpi eq, %1, %2 : index
|
||||
assert %4, "mismatching contracting dimension for aten.mm"
|
||||
cf.assert %4, "mismatching contracting dimension for aten.mm"
|
||||
%5 = linalg.init_tensor [%0, %3] : tensor<?x?xf32>
|
||||
%6 = linalg.fill(%cst, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
|
||||
%7 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%6 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
|
Loading…
Reference in New Issue