mirror of https://github.com/llvm/torch-mlir
Add `aten.mm` to linalg lowering.
This is our first op with error semantics, and it stresses the system. There
are a few design notes of special interest:

- RefineTypes.cpp's note about shape inference in the presence of code that
  dynamically produces an error, where that fact is provable statically.
- ATenToLinalg.cpp's notes about future automation of the ATen->linalg path.
- The notes in Passes.td about using low-tech `std.assert` ops instead of
  `shape.assuming`.

Note: Doesn't work on IREE yet due to the `std.assert` op (needs to be
lowered to `vm.fail` on the IREE side).
parent 28a0f02746
commit f5dfa02523
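For reference, the new pass can be exercised in isolation with npcomp-opt
(this mirrors the RUN line of the lit test added below; the input file name
is illustrative):

  npcomp-opt <input.mlir -convert-aten-to-linalg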
@@ -0,0 +1,51 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import typing

import torch
import torch_mlir

import npcomp
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
from npcomp.compiler.utils import logging

import test_utils

logging.enable()

# RUN: %PYTHON %s | npcomp-opt | FileCheck %s

mb = torch_mlir.ModuleBuilder()

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)

test_module = TestModule()
class_annotator = torch_mlir.ClassAnnotator()
recursivescriptmodule = torch.jit.script(test_module)
torch.jit.save(recursivescriptmodule, '/tmp/foo.pt')

class_annotator.exportNone(recursivescriptmodule._c._type())
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
class_annotator.annotateShapesAndDtypes(recursivescriptmodule._c._type(), ['forward'], [
    None,
    ([-1, -1], torch.float32),
    ([-1, -1], torch.float32),
])
# TODO: Automatically handle unpacking the Python class RecursiveScriptModule
# into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator)
# mb.module.operation.print()

backend = refjit.CompilerBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)

torch.manual_seed(0)
lhs = torch.rand(2, 3)
rhs = torch.rand(3, 4)
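# Check the compiled module against eager-mode PyTorch on the same inputs.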
test_utils.compare_outputs(test_module.forward, jit_module.forward, lhs, rhs)
@@ -0,0 +1,21 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#define NPCOMP_CONVERSION_ATENTOLINALG_ATENTOLINALG_H

#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<FuncOp>> createConvertATenToLinalgPass();
} // namespace NPCOMP
} // namespace mlir

#endif // NPCOMP_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
@@ -20,6 +20,85 @@ def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> {
   let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()";
 }

+def ConvertATenToLinalg : Pass<"convert-aten-to-linalg", "FuncOp"> {
+  let summary = "Convert recognized ATen to Linalg ops";
+  let description = [{
+    Convert ATen ops to linalg ops.
+
+    This pass's main responsibility is to bridge the world between ops
+    that safely terminate the program in case of operand shape mismatches
+    (ATen) and ops where such mismatches are undefined behavior (linalg).
+
+    To model the termination of the program for implementing error guards,
+    we use the `std.assert` op. This is a design decision at variance with
+    other passes in npcomp, such as `convert-tcf-to-std` and
+    `convert-tcf-to-linalg`, which use the `shape` dialect's witness system
+    (the `shape.cstr_*` family of ops feeding into `shape.assuming` regions).
+    Those passes will be subsumed by this one. The reasons for this change
+    are heuristic, but boil down to:
+
+    1. The modeling of `shape.assuming` is odd: it uses a region, which is
+       not a good fit for modeling error guards. Regions mark a "start" and
+       an "end" (which is their nesting property), but modeling assertions in
+       the program doesn't fit into that. For assertions, only the "start"
+       matters (once tested, a predicate remains true "forever" -- it doesn't
+       end at the "yield" of the region). Thus, having regions places
+       arbitrary "end"s that just add IR structure with no semantic value for
+       modeling this problem! (And to make things worse, the "end"s, which we
+       don't need, are what require "yielding" values, which interrupts
+       use-def chains.) Consider the different structural properties of
+       regions:
+       a. IsolatedFromAbove region:
+          - "start" interrupts use-def chains,
+          - "end" interrupts use-def chains,
+          - structurally protects from intra-block upward and downward
+            code motion.
+       b. Capturing region (like `shape.assuming`):
+          - "start" does not interrupt use-def chains,
+          - "end" interrupts use-def chains,
+          - structurally protects from intra-block upward and downward
+            code motion.
+       c. What we "ideally" want:
+          - "start" interrupts use-def chains (can be pruned though),
+          - no "end" IR structure!
+          - structurally protects from intra-block upward code motion
+            (but not downward code motion!).
+          - Observation: We probably can't get all of this, but overall this
+            problem is much better suited for a "MemorySSA"-like abstraction,
+            call it "EffectSSA", which is constructed on-demand based on
+            MLIR's effect modeling system. (With witnesses/`shape.assuming`,
+            which only cover the effects the IR creator encoded, it is easy
+            to forget to handle effects other than those encoded in the
+            witness structure.)
+    2. The presence of `shape.assuming` regions tends to create highly nested
+       IR structures, which don't interoperate well with any other IR
+       structures, and creates very bulky IR (and IR creation code). In
+       general, if we are going to do anything with anything (e.g.
+       canonicalize), we end up needing to either:
+       a. Flatten the `shape.assuming` IR (defeating the purpose of having
+          it).
+       b. Do some sort of `shape.assuming` "region merging".
+       c. Have special patterns that handle a subset of special cases
+          (looking through "yields" and such) and don't generalize.
+    3. Witnesses tend to encourage non-scalable peephole transformations,
+       which tend to make analyses/transformations non-robust to the presence
+       of control flow and side-effecting ops (it is easy to forget to handle
+       side effects other than those modeled by the witness system).
+    4. All this code operates on ranked tensors, for which using individual
+       SSA values for sizes (rather than a "shape type") seems to work really
+       well at this level of abstraction, based on prior experience in IREE.
+       (Unranked code tends to benefit from having a discrete "shape type" to
+       model shapes.)
+
+    We will see if we end up needing something like `shape.assuming`, but for
+    now, it seems likely we can do something simpler and just bypass it. The
+    design of having an EffectSSA that is constructed on-demand seems very
+    compelling for modeling effects more broadly.
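+
+    For illustration, the two styles of error guard look roughly like this
+    (a hand-written sketch with operands and types abbreviated, not IR
+    emitted verbatim by this pass):
+
+      // Witness style: the guarded computation is nested in a region and
+      // its results must be yielded back out.
+      %w = shape.cstr_eq %lhsShape, %rhsShape : tensor<?xindex>, tensor<?xindex>
+      %0 = shape.assuming %w -> (tensor<?x?xf32>) {
+        %m = linalg.matmul ...
+        shape.assuming_yield %m : tensor<?x?xf32>
+      }
+
+      // Assert style: only a "start", no region; use-def chains untouched.
+      %eq = cmpi eq, %lhsDim1, %rhsDim0 : index
+      assert %eq, "mismatching contracting dimension"
+      %result = linalg.matmul ...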
+  }];
+  let constructor = "mlir::NPCOMP::createConvertATenToLinalgPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // Basicpy conversions
 //===----------------------------------------------------------------------===//
@@ -642,4 +642,3 @@ def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<Torch
   let results = (outs
   );
 }
@@ -0,0 +1,142 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" // TODO: For `memref.dim`.
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"

using namespace mlir;
using namespace mlir::NPCOMP;

// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
// -----------------------------------------------------------------------------
// This is going to eventually be O(#aten ops), which is in the 100s.
//
// Most of these patterns consist of:
// 1. Checking that the operand/result types and other static properties are
//    good enough to create a valid linalg op (such as operands being of
//    ranks/dtypes acceptable to the linalg op).
// 2. Creating dynamic error guards, usually checking a predicate on the
//    compatibility of operand shapes.
// 3. Creating init tensors for the computation op. Usually this involves
//    reifying IR for a shape transfer function based on the operand shapes.
// 4. Creating a named linalg op to replace the original op.
//
// TODO: Use linalg OpDSL to autogenerate at least 1)/2)/3) such
// that these patterns become mostly mechanical associations of
// "aten.foo -> linalg.foo".

static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
                                                 PatternRewriter &rewriter) {
  // For now, use a small allowlist of types we don't reject.
  // The main culprit in practice is that !numpy.any_dtype might be present
  // if shape/dtype inference wasn't good enough.
  auto isValidLinalgType = [](Type type) {
    if (auto rankedTensor = type.dyn_cast<RankedTensorType>()) {
      if (BaseMemRefType::isValidElementType(rankedTensor.getElementType()))
        return true;
    }
    if (type.isa<FloatType, IntegerType, IndexType>())
      return true;
    return false;
  };
  bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
               llvm::all_of(op->getResultTypes(), isValidLinalgType);
  if (!valid)
    return rewriter.notifyMatchFailure(op, "type cannot be lowered to linalg");
  return success();
}

LogicalResult convertMmOp(aten::MmOp op, PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  Value lhs = op.getOperand(0);
  Value rhs = op.getOperand(1);

  // A user can write an erroneous program where `aten.mm` is in fact called
  // with operands of invalid rank or dtype. We cannot convert to linalg in
  // this case or we will get a verifier error, which corresponds to breaking
  // *internal* compiler invariants, and for a user manifests as a compiler
  // crash in the worst case (such as if we try to canonicalize/fold/print the
  // invalid op before the verifier gets to see it -- also, release builds of
  // a mature compiler usually have the verifier turned off for compile-time
  // reasons).
  //
  // The compiler cannot crash even if the user wrote an erroneous program!
  if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
    return failure();
  if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
      rhs.getType().cast<RankedTensorType>().getRank() != 2) {
    return rewriter.notifyMatchFailure(
        op, "expected both operands to aten.mm to be rank 2");
  }

  Value lhsDim0 = rewriter.create<memref::DimOp>(loc, lhs, 0);
  Value lhsDim1 = rewriter.create<memref::DimOp>(loc, lhs, 1);
  Value rhsDim0 = rewriter.create<memref::DimOp>(loc, rhs, 0);
  Value rhsDim1 = rewriter.create<memref::DimOp>(loc, rhs, 1);
  Value contractingDimEqual =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsDim1, rhsDim0);
  rewriter.create<AssertOp>(
      loc, contractingDimEqual,
      rewriter.getStringAttr("mismatching contracting dimension for aten.mm"));

  Type elementType = op.getType().cast<TensorType>().getElementType();
  Value initTensor = rewriter.create<linalg::InitTensorOp>(
      loc, ValueRange{lhsDim0, rhsDim1}, elementType);
  Value c0 = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
  Value zeroFill =
      rewriter.create<linalg::FillOp>(loc, initTensor, c0).getResult(0);
  Value matmul = rewriter
                     .create<linalg::MatmulOp>(loc, zeroFill.getType(),
                                               ValueRange{lhs, rhs}, zeroFill)
                     .getResult(0);
  // When constructed with just dynamic sizes, InitTensorOp will have a result
  // type which has all `?`'s for dimensions, which might not be the result
  // type of `op`. The constraints on later linalg ops mean that the result of
  // the MatmulOp will have this type too. So cast it to the desired type so
  // that in the end we have the original result type.
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), matmul);

  return success();
}

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------

namespace {
class ConvertATenToLinalg
    : public ConvertATenToLinalgBase<ConvertATenToLinalg> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }

  void runOnOperation() override {
    (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
  }

  FrozenRewritePatternList getPatterns() {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.add(convertMmOp);
    return std::move(patterns);
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertATenToLinalgPass() {
  return std::make_unique<ConvertATenToLinalg>();
}
@@ -0,0 +1,18 @@
add_npcomp_conversion_library(NPCOMPATenToLinalg
  ATenToLinalg.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/ATenToLinalg

  DEPENDS
  NPCOMPConversionPassIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRPass
  NPCOMPATenDialect
  MLIRLinalg
)
@@ -173,6 +173,5 @@ void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) {
   patterns.add<ConvertUnary<aten::TanhOp, tcf::TanhOp>>(context);
   patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
   patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context);
-  patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
   patterns.add<ConvertATenConv2d>(context);
 }
@@ -1,3 +1,4 @@
+add_subdirectory(ATenToLinalg)
 add_subdirectory(ATenToTCF)
 add_subdirectory(BasicpyToStd)
 add_subdirectory(NumpyToTCF)
@@ -8,6 +8,7 @@
 #include "npcomp/Conversion/Passes.h"

+#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
 #include "npcomp/Conversion/ATenToTCF/Passes.h"
 #include "npcomp/Conversion/BasicpyToStd/Passes.h"
 #include "npcomp/Conversion/NumpyToTCF/Passes.h"
@@ -75,7 +75,8 @@ bool operator==(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
          std::make_tuple(rhs.hasRank, rhs.sizes, rhs.elementType);
 }

-// static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ValueKnowledge &knowledge) {
+// static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ValueKnowledge
+// &knowledge) {
 //   os << "hasRank = " << knowledge.hasRank << ", sizes = [";
 //   llvm::interleaveComma(knowledge.sizes, os);
 //   os << "]"
@@ -83,6 +84,16 @@ bool operator==(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
 //   return os;
 // }

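+// Note: `!numpy.any_dtype` represents an unknown dtype here: joining it with
+// any type returns that other type, and joining two distinct concrete types
+// conservatively gives up and returns `!numpy.any_dtype` again.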
+Type joinElementTypes(Type lhs, Type rhs) {
+  if (lhs.isa<Numpy::AnyDtypeType>())
+    return rhs;
+  if (rhs.isa<Numpy::AnyDtypeType>())
+    return lhs;
+  if (lhs == rhs)
+    return lhs;
+  return Numpy::AnyDtypeType::get(lhs.getContext());
+}
+
 // Given two pieces of static knowledge, calculate conservatively the
 // information we can be sure about.
 ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
@@ -116,13 +127,7 @@ ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
     }
   }

-  if (!lhs.elementType || lhs.elementType.isa<Numpy::AnyDtypeType>()) {
-    result.elementType = rhs.elementType;
-  } else if (!rhs.elementType || rhs.elementType.isa<Numpy::AnyDtypeType>()) {
-    result.elementType = lhs.elementType;
-  } else if (lhs.elementType == rhs.elementType) {
-    result.elementType = lhs.elementType;
-  }
+  result.elementType = joinElementTypes(lhs.elementType, rhs.elementType);
   return result;
 }

|
||||||
DenseMap<Value, ValueKnowledge> facts;
|
DenseMap<Value, ValueKnowledge> facts;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Return the knowledge for the results of an op, based on the knowledge of the
|
||||||
|
// operands and any information intrinsic to `op`.
|
||||||
|
static SmallVector<ValueKnowledge>
|
||||||
|
forwardKnowledgeTransferFunction(Operation *op,
|
||||||
|
ArrayRef<ValueKnowledge> operandKnowledge) {
|
||||||
|
if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
|
||||||
|
return {operandKnowledge[0]};
|
||||||
|
} else if (isa<aten::MmOp>(op)) {
|
||||||
|
auto &lhs = operandKnowledge[0];
|
||||||
|
auto &rhs = operandKnowledge[1];
|
||||||
|
auto knowledge =
|
||||||
|
ValueKnowledge::getMostConservativeKnowledge(op->getContext());
|
||||||
|
knowledge.hasRank = true;
|
||||||
|
// WARNING: We could be more precise here by calculating the output
|
||||||
|
// shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
|
||||||
|
// at this stage in the compiler because we don't really have many static
|
||||||
|
// guarantees about the input ranks because `aten` ops do dynamic error
|
||||||
|
// checking and safely abort the program. There is nothing preventing us
|
||||||
|
// from (correctly!) statically inferring the shapes of the operands to
|
||||||
|
// shapes that are guaranteed to cause an error at runtime.
|
||||||
|
//
|
||||||
|
// Example: Suppose a user program calls `aten.mm` with two rank-0 operands.
|
||||||
|
// The program emits an error when invoked, but when running this pass,
|
||||||
|
// we will (correctly!) infer `lhs.hasRank && lhs.sizes.size() == 0 &&
|
||||||
|
// rhs.hasRank && rhs.sizes.size() == 0` -- it's not safe to access
|
||||||
|
// `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing this transfer
|
||||||
|
// function, it's not as simple as taking `lhs.sizes[0]` and `rhs.sizes[1]`,
|
||||||
|
// as both of those might read out of bounds of the array. It would require
|
||||||
|
// more complicated logic.
|
||||||
|
//
|
||||||
|
// Just knowing dtypes and ranks is sufficient at this stage
|
||||||
|
// in the compiler. The precise per-dimension size propagation is best done
|
||||||
|
// lower in the stack, such as at the linalg level, where we have more
|
||||||
|
// static guarantees and more structure.
|
||||||
|
knowledge.sizes.resize(2, kUnknownSize);
|
||||||
|
// TODO: Investigate promotion rules if element types mismatch.
|
||||||
|
// This is conservatively correct, assuming that if both element types are
|
||||||
|
// the same, then the result is of that same element type.
|
||||||
|
knowledge.elementType = joinElementTypes(lhs.elementType, rhs.elementType);
|
||||||
|
return {knowledge};
|
||||||
|
}
|
||||||
|
return SmallVector<ValueKnowledge>(
|
||||||
|
op->getNumResults(),
|
||||||
|
ValueKnowledge::getMostConservativeKnowledge(op->getContext()));
|
||||||
|
}
|
||||||
|
|
||||||
void TypeAnalyzer::propagate(Region ®ion) {
|
void TypeAnalyzer::propagate(Region ®ion) {
|
||||||
bool changed;
|
bool changed;
|
||||||
do {
|
do {
|
||||||
|
@ -168,10 +219,13 @@ void TypeAnalyzer::propagate(Region ®ion) {
|
||||||
for (Operation &op : block.getOperations()) {
|
for (Operation &op : block.getOperations()) {
|
||||||
for (Value v : op.getResults())
|
for (Value v : op.getResults())
|
||||||
changed |= incorporateKnowledge(v, getKnowledgeFromType(v.getType()));
|
changed |= incorporateKnowledge(v, getKnowledgeFromType(v.getType()));
|
||||||
if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
|
auto operandKnowledge = llvm::to_vector<6>(llvm::map_range(
|
||||||
changed |= incorporateKnowledge(op.getResult(0),
|
op.getOperands(), [&](Value v) { return getKnowledge(v); }));
|
||||||
getKnowledge(op.getOperand(0)));
|
SmallVector<ValueKnowledge> resultKnowledge =
|
||||||
}
|
forwardKnowledgeTransferFunction(&op, operandKnowledge);
|
||||||
|
assert(resultKnowledge.size() == op.getNumResults());
|
||||||
|
for (auto t : llvm::zip(op.getResults(), resultKnowledge))
|
||||||
|
changed |= incorporateKnowledge(std::get<0>(t), std::get<1>(t));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} while (changed);
|
} while (changed);
|
||||||
|
|
|
@@ -61,6 +61,7 @@ TORCH_TO_TCP_PASSES = (
     # Lower to TCP (+ guards) which is the input to codegen backends.
     # Most of this should be subsumed by aten->linalg+guards conversions.
     # (the guard generation will be automated from the linalg Op DSL)
+    "func(convert-aten-to-linalg)",
     "func(convert-aten-to-tcf)",
     "func(convert-tcf-to-std)",
     "func(convert-elementwise-to-linalg)",
@@ -0,0 +1,44 @@
// RUN: npcomp-opt <%s -convert-aten-to-linalg | FileCheck %s

// CHECK-LABEL: func @aten.mm$basic(
// CHECK-SAME: %[[LHS:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x2xf32> {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[LHS_DIM_0:.*]] = memref.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[LHS_DIM_1:.*]] = memref.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[RHS_DIM_0:.*]] = memref.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[RHS_DIM_1:.*]] = memref.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[EQ:.*]] = cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index
// CHECK: assert %[[EQ]], "mismatching contracting dimension for aten.mm"
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[LHS_DIM_0]], %[[RHS_DIM_1]]] : tensor<?x?xf32>
// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[CF0]]) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
// CHECK: return %[[CASTED]] : tensor<?x2xf32>
func @aten.mm$basic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x2xf32> {
  %0 = "aten.mm"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x2xf32>
  return %0 : tensor<?x2xf32>
}

// CHECK-LABEL: func @aten.mm$no_convert$missing_dtype
func @aten.mm$no_convert$missing_dtype(%arg0: tensor<*x!numpy.any_dtype>, %arg1: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
  // CHECK-NEXT: aten.mm
  %0 = "aten.mm"(%arg0, %arg1) : (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
  return %0 : tensor<*x!numpy.any_dtype>
}

// CHECK-LABEL: func @aten.mm$no_convert$wrong_rank
func @aten.mm$no_convert$wrong_rank(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<*x!numpy.any_dtype> {
  // CHECK-NEXT: aten.mm
  %0 = "aten.mm"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<*x!numpy.any_dtype>
  return %0 : tensor<*x!numpy.any_dtype>
}

// CHECK-LABEL: func @aten.mm$no_convert$result_missing_dtype
func @aten.mm$no_convert$result_missing_dtype(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<*xf32> {
  // CHECK-NEXT: aten.mm
  %0 = "aten.mm"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
@@ -24,6 +24,19 @@ func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {

 // -----

+// CHECK-LABEL: func @f(
+// CHECK-SAME: %[[LHS:.*]]: tensor<2x?xf32>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype> {
+// CHECK: %[[MM:.*]] = "aten.mm"(%[[LHS]], %[[RHS]]) : (tensor<2x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[MM]] : tensor<?x?xf32> to tensor<*x!numpy.any_dtype>
+// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
+func @f(%arg0: tensor<2x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype> {
+  %1 = "aten.mm"(%arg0, %arg1) : (tensor<2x?xf32>, tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype>
+  return %1 : tensor<*x!numpy.any_dtype>
+}
+
+// -----
+
 // CHECK-LABEL: func @f
 func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
   // Check propagation through multiple ops.