mirror of https://github.com/llvm/torch-mlir
Add folder for aten.broadcast_to on unchanged static shapes (#2421)
parent
34a0897e1b
commit
1fc4314b62
|
@ -8356,6 +8356,7 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
|
def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
|
||||||
|
|
|
@ -2317,6 +2317,25 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
|
||||||
return list.getElements()[0];
|
return list.getElements()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenBroadcastToOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
|
||||||
|
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||||
|
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
|
||||||
|
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
|
||||||
|
return nullptr;
|
||||||
|
if (inType.getSizes().size() != outType.getSizes().size() ||
|
||||||
|
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
|
||||||
|
return nullptr;
|
||||||
|
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
|
||||||
|
if (inType.getSizes()[i] != outType.getSizes()[i])
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenSliceTensorOp
|
// AtenSliceTensorOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -542,7 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
||||||
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
|
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||||
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
||||||
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
|
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
|
||||||
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
|
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
|
||||||
|
|
|
@ -1971,6 +1971,18 @@ func.func @torch.aten.cat$fold_single_operand(%arg0: !torch.tensor) -> !torch.te
|
||||||
return %1: !torch.tensor
|
return %1: !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.broadcast_to$fold(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
|
||||||
|
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32>
|
||||||
|
func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%int4 = torch.constant.int 4
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[3,4,2],f32>, !torch.list<int> -> !torch.vtensor<[3,4,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
|
// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
|
||||||
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
|
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
|
||||||
// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>
|
// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>
|
||||||
|
|
Loading…
Reference in New Issue