Update external/llvm-project

- Add `qualified` to the ODS assembly formats because of
https://reviews.llvm.org/D113873 and https://reviews.llvm.org/D116905
- Needed to revert https://github.com/llvm/torch-mlir/pull/520 as it
was based on an old torch version.
https://github.com/llvm/torch-mlir/pull/527 will bring this back with
a better design.
- Change ConvertAtenCatOp to use more accurate tensor shape info and
as much static info as possible to satisfy the `tensor.insert_slice`
verification code added by https://reviews.llvm.org/D114715
- Other minor fixes
pull/526/head
dan 2022-01-04 10:00:48 -06:00 committed by Yi Zhang
parent 40efd2cb8e
commit 3745f54489
17 changed files with 332 additions and 324 deletions

@ -1 +1 @@
Subproject commit 966b72098363d44adf2882b9c34fcdbe344ff913
Subproject commit 63f0c00d38ee7879239975a6743d4e6c7847b725

File diff suppressed because it is too large Load Diff

View File

@ -27,7 +27,7 @@ def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
}
def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
@ -42,7 +42,7 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
let results = (outs
AnyTorchType:$result
);
let assemblyFormat = "$tup `,` $i attr-dict `:` type($tup) `,` type($i) `->` type($result)";
let assemblyFormat = "$tup `,` $i attr-dict `:` qualified(type($tup)) `,` qualified(type($i)) `->` qualified(type($result))";
let hasCanonicalizer = 1;
}
@ -57,7 +57,7 @@ def Torch_PrimDeviceOp : Torch_Op<"prim.device", [
let results = (outs
Torch_DeviceType:$result
);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
}
def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [
@ -71,7 +71,7 @@ def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
let hasFolder = 1;
}
@ -85,7 +85,7 @@ def Torch_PrimTupleUnpackOp : Torch_Op<"prim.TupleUnpack", [
let results = (outs
Variadic<AnyTorchType>:$results
);
let assemblyFormat = "$tup attr-dict `:` type($tup) `->` type($results)";
let assemblyFormat = "$tup attr-dict `:` qualified(type($tup)) `->` qualified(type($results))";
let hasCanonicalizer = 1;
}
@ -100,7 +100,7 @@ def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
}
def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
@ -114,7 +114,7 @@ def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}
def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
@ -129,7 +129,7 @@ def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
}
def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
@ -143,7 +143,7 @@ def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}
def Torch_PrimMaxIntOp : Torch_Op<"prim.max.int", [
@ -158,7 +158,7 @@ def Torch_PrimMaxIntOp : Torch_Op<"prim.max.int", [
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
let assemblyFormat = "$a `,` $b attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `->` qualified(type($result))";
}
def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [
@ -171,7 +171,7 @@ def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [
);
let results = (outs
);
let assemblyFormat = "$msg attr-dict `:` type($msg)";
let assemblyFormat = "$msg attr-dict `:` qualified(type($msg))";
}
def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
@ -184,7 +184,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
let results = (outs
AnyTorchType:$result
);
let assemblyFormat = " attr-dict `:` type($result)";
let assemblyFormat = " attr-dict `:` qualified(type($result))";
let hasCanonicalizer = 1;
}
@ -199,7 +199,7 @@ def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
let results = (outs
AnyTorchType:$result
);
let assemblyFormat = "$x attr-dict `:` type($x) `->` type($result)";
let assemblyFormat = "$x attr-dict `:` qualified(type($x)) `->` qualified(type($result))";
}
def Torch_PrimPrintOp : Torch_Op<"prim.Print", [
@ -211,7 +211,7 @@ def Torch_PrimPrintOp : Torch_Op<"prim.Print", [
);
let results = (outs
);
let assemblyFormat = "`(` $operands `)` attr-dict `:` type($operands)";
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands))";
}
def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
@ -224,6 +224,6 @@ def Torch_PrimTolistOp : Torch_Op<"prim.tolist", [
let results = (outs
Variadic<AnyTorchType>:$results
);
let assemblyFormat = "`(` $operands `)` attr-dict `:` type($operands) `->` type($results)";
let assemblyFormat = "`(` $operands `)` attr-dict `:` qualified(type($operands)) `->` qualified(type($results))";
}

View File

@ -31,6 +31,6 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
let results = (outs
AnyTorchTensorType:$Y
);
let assemblyFormat = "$X `,` $W_prepack `,` $Y_scale_i `,` $Y_zero_point_i attr-dict `:` type($X) `,` type($W_prepack) `,` type($Y_scale_i) `,` type($Y_zero_point_i) `->` type($Y)";
let assemblyFormat = "$X `,` $W_prepack `,` $Y_scale_i `,` $Y_zero_point_i attr-dict `:` qualified(type($X)) `,` qualified(type($W_prepack)) `,` qualified(type($Y_scale_i)) `,` qualified(type($Y_zero_point_i)) `->` qualified(type($Y))";
}

View File

@ -64,7 +64,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
let regions = (region SizedRegion<1>:$region);
let verifier = "return ::verify(*this);";
let assemblyFormat = "$region attr-dict `:` type($result)";
let assemblyFormat = "$region attr-dict `:` qualified(type($result))";
let extraClassDeclaration = [{
StringRef getClassName() { return getType().getClassName(); }
@ -97,7 +97,7 @@ def Torch_SlotOp : Torch_Op<"slot", [
let results = (outs);
let assemblyFormat = [{
$name `,` $value attr-dict `:` type($value)
$name `,` $value attr-dict `:` qualified(type($value))
}];
}
@ -272,7 +272,7 @@ def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
// TODO: Have a SingleBlockExplicitTerminator trait.
let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
let assemblyFormat = "$initialValue attr-dict `:` type($initialValue)";
let assemblyFormat = "$initialValue attr-dict `:` qualified(type($initialValue))";
}
def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
@ -284,7 +284,7 @@ def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$slot attr-dict `:` type($result)
$slot attr-dict `:` qualified(type($result))
}];
}
@ -298,7 +298,7 @@ def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
let results = (outs);
let assemblyFormat = [{
$slot `=` $value attr-dict `:` type($value)
$slot `=` $value attr-dict `:` qualified(type($value))
}];
}
@ -319,7 +319,7 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
let results = (outs Variadic<AnyTorchType>:$results);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($results)
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($results))
}];
}
@ -343,7 +343,7 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
);
let assemblyFormat = [{
$elements attr-dict `:` type($elements) `->` type($result)
$elements attr-dict `:` qualified(type($elements)) `->` qualified(type($result))
}];
}
@ -385,7 +385,7 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
let verifier = "return ::verify(*this);";
let assemblyFormat = [{
`keys` `(` ($keys^ `:` type($keys))? `)` `values` `(` ($values^ `:` type($values))? `)` attr-dict `->` type($result)
`keys` `(` ($keys^ `:` qualified(type($keys)))? `)` `values` `(` ($values^ `:` qualified(type($values)))? `)` attr-dict `->` qualified(type($result))
}];
let extraClassDeclaration = [{
@ -401,7 +401,7 @@ def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$receiver `[` $name `]` attr-dict `:` type($receiver) `->` type($result)
$receiver `[` $name `]` attr-dict `:` qualified(type($receiver)) `->` qualified(type($result))
}];
}
@ -416,7 +416,7 @@ def Torch_PrimSetAttrOp : Torch_Op<"prim.SetAttr", []> {
let results = (outs);
let assemblyFormat = [{
$receiver `[` $name `]` `=` $value attr-dict `:` type($receiver) `,` type($value)
$receiver `[` $name `]` `=` $value attr-dict `:` qualified(type($receiver)) `,` qualified(type($value))
}];
}
@ -431,7 +431,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($receiver) `,` functional-type($operands, $result)
$receiver `[` $name `]` `(` $operands `)` attr-dict `:` qualified(type($receiver)) `,` functional-type($operands, $result)
}];
}
@ -478,7 +478,7 @@ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
let assemblyFormat = [{
$shouldContinue `,`
`iter` `(` ($iterArgs^ `:` type($iterArgs))? `)` attr-dict
`iter` `(` ($iterArgs^ `:` qualified(type($iterArgs)))? `)` attr-dict
}];
}
@ -525,7 +525,7 @@ def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
let results = (outs);
let assemblyFormat = [{
attr-dict ($results^ `:` type($results))?
attr-dict ($results^ `:` qualified(type($results)))?
}];
}
@ -662,7 +662,7 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
$operand attr-dict `:` qualified(type($operand)) `to` qualified(type($result))
}];
}
@ -698,7 +698,7 @@ def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
let results = (outs Torch_LinearParamsType:$result);
let assemblyFormat = [{
$weight (`,` $bias^)? attr-dict `:` type($weight) (`,` type($bias)^)?
$weight (`,` $bias^)? attr-dict `:` qualified(type($weight)) (`,` qualified(type($bias))^)?
}];
}
@ -727,7 +727,7 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
let assemblyFormat = [{
$int_repr `,` $scale `,` $offset attr-dict
`:` type($int_repr) `,` type($scale) `,` type($offset) `->` type($result)
`:` qualified(type($int_repr)) `,` qualified(type($scale)) `,` qualified(type($offset)) `->` qualified(type($result))
}];
}
@ -754,7 +754,7 @@ def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [
let results = (outs Torch_NonValueTensorType:$result);
let assemblyFormat = [{
`(` $value `)` attr-dict `:` type($result)
`(` $value `)` attr-dict `:` qualified(type($result))
}];
let extraClassDeclaration = [{
@ -786,7 +786,7 @@ def Torch_ValueTensorLiteralOp : Torch_Op<"vtensor.literal", [
let results = (outs Torch_ValueTensorType:$result);
let assemblyFormat = [{
`(` $value `)` attr-dict `:` type($result)
`(` $value `)` attr-dict `:` qualified(type($result))
}];
let hasFolder = 1;
@ -817,7 +817,7 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
AnyTorchTensorType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
$operand attr-dict `:` qualified(type($operand)) `to` qualified(type($result))
}];
let hasCanonicalizer = 1;
}
@ -849,7 +849,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
Torch_NonValueTensorType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($result)
$operand attr-dict `:` qualified(type($result))
}];
let verifier = "return ::verify(*this);";
}
@ -879,7 +879,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
Torch_ValueTensorType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($result)
$operand attr-dict `:` qualified(type($result))
}];
let verifier = "return ::verify(*this);";
}
@ -908,7 +908,7 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
);
let assemblyFormat = [{
$value `overwrites` $overwritten attr-dict
`:` type($value) `,` type($overwritten)
`:` qualified(type($value)) `,` qualified(type($overwritten))
}];
}

View File

@ -40,7 +40,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
AnyTensor:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
}
@ -58,7 +58,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
Torch_ValueTensorType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
}

View File

@ -11,7 +11,7 @@
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
@ -134,12 +134,7 @@ static Value castIndexToInt(OpBuilder &b, Location loc, Value idx) {
}
static Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
if (auto tensorType = v.getType().cast<RankedTensorType>()) {
if (!tensorType.isDynamicDim(dim))
return b.create<arith::ConstantOp>(
loc, b.getIndexAttr(tensorType.getShape()[dim]));
}
return b.create<tensor::DimOp>(loc, v, dim);
return b.createOrFold<tensor::DimOp>(loc, v, dim);
}
static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
@ -2202,8 +2197,8 @@ public:
SmallVector<ReassociationIndices> reassociation(1);
for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
reassociation[0].push_back(i);
input = rewriter.create<linalg::TensorCollapseShapeOp>(
argmaxOp->getLoc(), input, reassociation);
input = rewriter.create<tensor::CollapseShapeOp>(argmaxOp->getLoc(),
input, reassociation);
// Becomes 0 for flattened tensor.
dim = 0;
// Recast to fix shape.
@ -2780,7 +2775,7 @@ public:
if (!(startDim >= -1 && startDim <= 0 && endDim >= -1 && endDim <= 0))
return rewriter.notifyMatchFailure(
op, "start_dim and end_dim must be in [-1, 0] when inputRank is 0");
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
op, resultType, adaptor.self(), reassociation);
return success();
}
@ -2797,7 +2792,7 @@ public:
if (i < startDim || i >= endDim)
j++;
}
Value collapsedTensor = rewriter.create<linalg::TensorCollapseShapeOp>(
Value collapsedTensor = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), adaptor.self(), reassociation);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
collapsedTensor);
@ -3022,12 +3017,12 @@ public:
Value result =
isCollapse
? rewriter
.create<linalg::TensorCollapseShapeOp>(
loc, adjustedResultType, castedInput, reassociation)
.create<tensor::CollapseShapeOp>(loc, adjustedResultType,
castedInput, reassociation)
.result()
: rewriter
.create<linalg::TensorExpandShapeOp>(
loc, adjustedResultType, castedInput, reassociation)
.create<tensor::ExpandShapeOp>(loc, adjustedResultType,
castedInput, reassociation)
.result();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
@ -3062,7 +3057,7 @@ public:
// being unit extent, it will be collapsed to a 0-D tensor.
if (resultRank == 0) {
SmallVector<ReassociationIndices> reassociation;
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
op, resultType, input, reassociation);
return success();
}
@ -3113,7 +3108,7 @@ public:
op, "expected output size mismatches with the result type rank");
if (isSqueezed) {
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
op, resultType, input, reassociation);
} else {
@ -3189,8 +3184,8 @@ public:
// Note: In case the operand tensor type is of unit rank and is statically
// shaped with unit dimension, the `reassociationMap` will be empty and the
// input will be collapsed to a 0-D tensor.
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
op, resultType, input, reassociationMap);
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
reassociationMap);
return success();
}
};
@ -3238,7 +3233,7 @@ public:
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
op, resultType, adaptor.self(), reassociationMap);
return success();
}
@ -3483,8 +3478,8 @@ public:
if (i != dim)
resultIdx++;
}
result = rewriter.create<linalg::TensorCollapseShapeOp>(loc, result,
reassociation);
result =
rewriter.create<tensor::CollapseShapeOp>(loc, result, reassociation);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
@ -3528,27 +3523,39 @@ public:
offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
for (int i = 0; i < rank; ++i)
sizes.push_back(rewriter.create<tensor::DimOp>(loc, tensors[0], i));
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, tensors[0], i));
// Calculate the size of the `dim` result dimension by adding the dim size
// of each tensor together.
Value resultDimSize = sizes[dim];
Value dimIndex = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), adaptor.dim());
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dim));
for (auto tensor : makeArrayRef(tensors).drop_front()) {
auto size = rewriter.create<tensor::DimOp>(loc, tensor, dimIndex);
resultDimSize = rewriter.create<arith::AddIOp>(loc, resultDimSize, size);
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
resultDimSize =
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
}
sizes[dim] = resultDimSize;
auto toOpFoldResult = [](Value v) -> OpFoldResult {
auto op = v.getDefiningOp<arith::ConstantIndexOp>();
if (!op)
return v;
return op.getValue();
};
Value result = rewriter.create<linalg::InitTensorOp>(
loc, sizes, newResultType.getElementType());
for (auto tensor : tensors) {
sizes[dim] = rewriter.create<tensor::DimOp>(loc, tensor, dimIndex);
result = rewriter.create<tensor::InsertSliceOp>(loc, tensor, result,
offsets, sizes, strides);
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, tensor);
result = rewriter.createOrFold<tensor::InsertSliceOp>(
loc, tensor, result,
llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
offsets[dim] =
rewriter.create<arith::AddIOp>(loc, offsets[dim], sizes[dim]);
rewriter.createOrFold<arith::AddIOp>(loc, offsets[dim], sizes[dim]);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);

View File

@ -53,8 +53,8 @@ public:
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
rewriter.eraseBlock(&dstRegion.back());
};
inlineIfCase(op.thenRegion(), scfIf.thenRegion());
inlineIfCase(op.elseRegion(), scfIf.elseRegion());
inlineIfCase(op.thenRegion(), scfIf.getThenRegion());
inlineIfCase(op.elseRegion(), scfIf.getElseRegion());
rewriter.replaceOp(op, scfIf.getResults());
return success();
}

View File

@ -11,6 +11,7 @@
#include "../PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@ -36,7 +37,7 @@ public:
LogicalResult
matchAndRewrite(AtenDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rank = rewriter.create<RankOp>(op->getLoc(), adaptor.self());
auto rank = rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.self());
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
op, getTypeConverter()->convertType(op.getType()), rank);
return success();
@ -155,6 +156,7 @@ public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<tensor::TensorDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
@ -162,7 +164,7 @@ public:
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect, StandardOpsDialect,
arith::ArithmeticDialect>();
arith::ArithmeticDialect, tensor::TensorDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });

View File

@ -111,7 +111,7 @@ public:
for (auto operand : llvm::enumerate(adaptor.getOperands())) {
if (operand.value().getType().isa<Torch::NoneType>())
continue;
auto it = typeBoundMap.find({call.callee(), operand.index()});
auto it = typeBoundMap.find({call.getCallee(), operand.index()});
if (it != typeBoundMap.end()) {
if (auto valueTensorType = it->second.dyn_cast<ValueTensorType>()) {
newOperands.push_back(copyTensorToType(
@ -126,7 +126,7 @@ public:
newOperands.push_back(operand.value());
}
CallOp newCall = rewriter.create<CallOp>(call.getLoc(), call.callee(),
CallOp newCall = rewriter.create<CallOp>(call.getLoc(), call.getCallee(),
convertedResults, newOperands);
int newOpResultIdx = 0;
SmallVector<Value> newResults;

View File

@ -351,7 +351,7 @@ static LogicalResult analyzeInstances(FuncOp func,
static FailureOr<Monomorphization>
createMonomorphizationForCall(CallOp op, BlockAndValueMapping &mapping,
SymbolTable &symbolTable) {
auto func = symbolTable.lookup<FuncOp>(op.callee());
auto func = symbolTable.lookup<FuncOp>(op.getCallee());
Monomorphization monomorphization;
monomorphization.func = func;
for (auto operand : llvm::enumerate(op->getOperands())) {

View File

@ -9,7 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"

View File

@ -333,17 +333,17 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
if operator.is_vararg:
assembly_operands = "`(` $operands `)`"
assembly_operand_types = "type($operands)"
assembly_operand_types = "qualified(type($operands))"
else:
assembly_operands = " `,` ".join("$" + arg["name"]
for arg in operator.arguments)
assembly_operand_types = " `,` ".join(
f"""type(${arg["name"]})""" for arg in operator.arguments)
f"""qualified(type(${arg["name"]}))""" for arg in operator.arguments)
if operator.is_varret:
assembly_result_types = "type($results)"
assembly_result_types = "qualified(type($results))"
else:
assembly_result_types = " `,` ".join(
f"""type(${ret["name"] or generic_result_name(e)})"""
f"""qualified(type(${ret["name"] or generic_result_name(e)}))"""
for e, ret in enumerate(operator.returns))
if assembly_operand_types and assembly_result_types:
maybe_arrow = " `->` "

View File

@ -194,7 +194,6 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
return torchMlirTorchNoneTypeGet(context);
}
case TypeKind::AnyType: {
auto anyType = torchType->cast<c10::AnyType>();
return torchMlirTorchAnyTypeGet(context);
}
case TypeKind::ClassType: {

View File

@ -5,7 +5,7 @@
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
@ -22,7 +22,7 @@ func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
@ -39,7 +39,7 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<18x2xf32> to tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
@ -56,7 +56,7 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x12xf32> to tensor<?x12xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32>
@ -73,7 +73,7 @@ func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2]
// CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>

View File

@ -6,7 +6,7 @@
// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
@ -18,7 +18,7 @@ func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtenso
// CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
@ -30,7 +30,7 @@ func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !tor
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32>
func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
@ -42,7 +42,7 @@ func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>)
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32>
func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
@ -54,7 +54,7 @@ func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>)
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32>
func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {

View File

@ -4,7 +4,7 @@
// CHECK-LABEL: func @torch.aten.dim(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32>
// CHECK: %[[RANK:.*]] = rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
// CHECK: %[[RANK:.*]] = tensor.rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
// CHECK: %[[RANK_I64:.*]] = arith.index_cast %[[RANK]] : index to i64
// CHECK: %[[RANK_TORCH_INT:.*]] = torch_c.from_i64 %[[RANK_I64]]
// CHECK: return %[[RANK_TORCH_INT]] : !torch.int