/**
* Copyright 2014 Duan Bingnan
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.pinus4j.entity;
import java.io.File;
import java.io.FileFilter;
import java.lang.reflect.Field;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import org.pinus4j.cluster.beans.IShardingKey;
import org.pinus4j.cluster.beans.ShardingKey;
import org.pinus4j.constant.Const;
import org.pinus4j.entity.annotations.CacheVersion;
import org.pinus4j.entity.annotations.DateTime;
import org.pinus4j.entity.annotations.Index;
import org.pinus4j.entity.annotations.Indexes;
import org.pinus4j.entity.annotations.PrimaryKey;
import org.pinus4j.entity.annotations.Table;
import org.pinus4j.entity.annotations.UpdateTime;
import org.pinus4j.entity.meta.DBTable;
import org.pinus4j.entity.meta.DBTableColumn;
import org.pinus4j.entity.meta.DBTableIndex;
import org.pinus4j.entity.meta.DBTablePK;
import org.pinus4j.entity.meta.DataTypeBind;
import org.pinus4j.entity.meta.EntityPK;
import org.pinus4j.entity.meta.PKName;
import org.pinus4j.entity.meta.PKValue;
import org.pinus4j.exceptions.DBOperationException;
import org.pinus4j.utils.BeansUtil;
import org.pinus4j.utils.StringUtil;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
/**
 * Manages metadata for loaded entity classes.
 *
 * <p>Classes annotated with {@link Table} are discovered by scanning a package (on the file system
 * or inside a jar), converted into {@link DBTable} descriptors, and indexed both by entity class and
 * by table name.
 *
 * @author shanwei Jul 22, 2015 1:33:30 PM
 */
public class DefaultEntityMetaManager implements IEntityMetaManager {

    /**
     * Context class loader of the thread that created this instance. Used to scan for and load the
     * entity classes that map to database tables.
     */
    private final ClassLoader classloader = Thread.currentThread().getContextClassLoader();

    /** Entity class -> table metadata. */
    private final static Map<Class<?>, DBTable> tableMap = new HashMap<Class<?>, DBTable>();

    /** Table name (lower-cased by converTo) -> table metadata. */
    private final static Map<String, DBTable> tableNameMap = Maps.newHashMap();

    /** All loaded table metadata, in discovery order. */
    private final static List<DBTable> tables = new ArrayList<DBTable>();

    /** Lazily created singleton; volatile so the double-checked locking in getInstance() is safe. */
    private volatile static IEntityMetaManager instance;

    private DefaultEntityMetaManager() {
    }

    /**
     * Returns the shared instance, creating it on first use.
     *
     * @return the singleton manager
     */
    public static IEntityMetaManager getInstance() {
        if (instance == null) {
            synchronized (DefaultEntityMetaManager.class) {
                if (instance == null) {
                    instance = new DefaultEntityMetaManager();
                }
            }
        }
        return instance;
    }

    /**
     * Reads the value of the single (non-union) primary key of {@code obj}.
     *
     * @throws IllegalStateException if the entity has a union key or no primary key
     */
    @Override
    public PKValue getNotUnionPkValue(Object obj) {
        PKName pkName = getNotUnionPkName(obj.getClass());
        Object pkValue = BeansUtil.getProperty(obj, pkName.getValue());
        return PKValue.valueOf(pkValue);
    }

    /**
     * Builds the full (possibly union) primary key of {@code obj}: one value per primary-key
     * column, in declaration order.
     */
    @Override
    public EntityPK getEntityPK(Object obj) {
        PKName[] pkNames = getPkName(obj.getClass());
        PKValue[] pkValues = new PKValue[pkNames.length];
        for (int i = 0; i < pkNames.length; i++) {
            pkValues[i] = PKValue.valueOf(BeansUtil.getProperty(obj, pkNames[i].getValue()));
        }
        return EntityPK.valueOf(pkNames, pkValues);
    }

    /**
     * Returns the name of the single primary key of {@code clazz}.
     *
     * @throws IllegalStateException if the entity has a union key or no primary key
     */
    @Override
    public PKName getNotUnionPkName(Class<?> clazz) {
        DBTable dbTable = this.getTableMeta(clazz);
        if (dbTable.isUnionPrimaryKey()) {
            throw new IllegalStateException("不支持联合主键, class=" + clazz);
        }
        List<DBTablePK> primaryKeys = dbTable.getPrimaryKeys();
        if (primaryKeys.isEmpty()) {
            throw new IllegalStateException("找不到主键 class=" + clazz);
        }
        return primaryKeys.get(0).getPKName();
    }

    /**
     * Returns the names of all primary-key columns of {@code clazz}.
     *
     * @throws IllegalStateException if the entity has no primary key
     */
    @Override
    public PKName[] getPkName(Class<?> clazz) {
        DBTable dbTable = this.getTableMeta(clazz);
        List<DBTablePK> primaryKeys = dbTable.getPrimaryKeys();
        if (primaryKeys.isEmpty()) {
            throw new IllegalStateException("找不到主键 class=" + clazz);
        }
        List<PKName> ePKList = new ArrayList<PKName>(primaryKeys.size());
        for (DBTablePK primaryKey : primaryKeys) {
            ePKList.add(primaryKey.getPKName());
        }
        return ePKList.toArray(new PKName[ePKList.size()]);
    }

    @Override
    public boolean isShardingEntity(Class<?> clazz) {
        return this.getTableMeta(clazz).isSharding();
    }

    /**
     * Builds the sharding key of {@code entity} from its cluster name and the value of its
     * {@code shardingBy} property.
     *
     * @throws DBOperationException if the sharding property cannot be read
     * @throws IllegalStateException if the sharding value is null
     */
    @Override
    public IShardingKey<?> getShardingKey(Object entity) {
        Class<?> clazz = entity.getClass();
        DBTable dbTable = this.getTableMeta(clazz);
        String clusterName = dbTable.getCluster();
        String shardingField = dbTable.getShardingBy();
        Object shardingValue = null;
        try {
            shardingValue = BeansUtil.getProperty(entity, shardingField);
        } catch (Exception e) {
            // NOTE(review): the underlying exception is not attached as the cause; consider a
            // (message, cause) constructor if DBOperationException provides one.
            throw new DBOperationException("获取sharding value失败, clazz=" + clazz + " field=" + shardingField);
        }
        if (shardingValue == null) {
            throw new IllegalStateException("shardingValue is null, clazz=" + clazz + " field=" + shardingField);
        }
        return new ShardingKey<Object>(clusterName, shardingValue);
    }

    @Override
    public String getClusterName(Class<?> clazz) {
        return this.getTableMeta(clazz).getCluster();
    }

    @Override
    public int getTableNum(Class<?> clazz) {
        return this.getTableMeta(clazz).getShardingNum();
    }

    @Override
    public String getTableName(Object entity, int tableIndex) {
        return getTableName(entity.getClass(), tableIndex);
    }

    /**
     * Returns the physical table name. A {@code tableIndex} of -1 yields the base name; any other
     * value is appended to the base name as a suffix.
     */
    @Override
    public String getTableName(Class<?> clazz, int tableIndex) {
        if (tableIndex == -1) {
            return getTableName(clazz);
        }
        return getTableName(clazz) + tableIndex;
    }

    @Override
    public String getTableName(Class<?> clazz) {
        return this.getTableMeta(clazz).getName();
    }

    @Override
    public boolean isCache(Class<?> clazz) {
        return this.getTableMeta(clazz).isCache();
    }

    @Override
    public boolean isUnionKey(Class<?> clazz) {
        return getTableMeta(clazz).isUnionPrimaryKey();
    }

    /**
     * Returns the single primary-key descriptor of {@code clazz}.
     *
     * @throws IllegalStateException if the entity has a union key or no primary key
     */
    @Override
    public DBTablePK getNotUnionPrimaryKey(Class<?> clazz) {
        List<DBTablePK> dbTablePK = getTableMeta(clazz).getPrimaryKeys();
        if (dbTablePK.size() > 1) {
            throw new IllegalStateException("不支持联合主键, class=" + clazz);
        }
        // Previously an empty key list surfaced as IndexOutOfBoundsException; report it the same
        // way as the sibling primary-key accessors do.
        if (dbTablePK.isEmpty()) {
            throw new IllegalStateException("找不到主键 class=" + clazz);
        }
        return dbTablePK.get(0);
    }

    /**
     * Discards all loaded metadata and scans {@code scanPackage} again.
     */
    @Override
    public void reloadEntity(String scanPackage) {
        synchronized (this) {
            tableMap.clear();
            // tableNameMap must be cleared too, otherwise lookups by name keep returning stale
            // metadata for tables that are no longer loaded.
            tableNameMap.clear();
            tables.clear();
            loadEntity(scanPackage);
        }
    }

    /**
     * Scans {@code scanPackage} (directories and jars visible to the class loader) and registers
     * every class annotated with {@link Table}.
     *
     * @throws RuntimeException wrapping any failure during scanning, class loading or conversion
     */
    @Override
    public void loadEntity(String scanPackage) {
        try {
            String pkgDirName = scanPackage.replace(".", "/");
            // Jar entries must lie strictly inside the package directory. A bare prefix match would
            // also accept sibling packages such as "<pkg>X/..." and derive malformed class names.
            String pkgDirPrefix = pkgDirName + "/";
            Enumeration<URL> dirs = classloader.getResources(pkgDirName);
            while (dirs.hasMoreElements()) {
                URL url = dirs.nextElement();
                String protocol = url.getProtocol();
                if ("file".equals(protocol)) {
                    String filePath = URLDecoder.decode(url.getFile(), "utf-8");
                    addClassesByFile(tables, scanPackage, filePath);
                } else if ("jar".equals(protocol)) {
                    // NOTE(review): the JarFile is not closed here; it may be a cached instance
                    // shared through the URL connection cache - confirm before adding close().
                    JarFile jar = ((JarURLConnection) url.openConnection()).getJarFile();
                    Enumeration<JarEntry> entries = jar.entries();
                    while (entries.hasMoreElements()) {
                        JarEntry entry = entries.nextElement();
                        String name = entry.getName();
                        if (name.charAt(0) == '/') {
                            name = name.substring(1);
                        }
                        if (!name.startsWith(pkgDirPrefix)) {
                            continue;
                        }
                        if (name.endsWith(".class") && !entry.isDirectory()) {
                            // strip "<pkgDir>/" and ".class", turning sub-directories into sub-packages
                            String className = name.substring(pkgDirPrefix.length(), name.length() - 6)
                                    .replace("/", ".");
                            registerIfEntity(tables, classloader.loadClass(scanPackage + "." + className));
                        }
                    }
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Recursively walks {@code packagePath} and registers every {@code @Table} class found.
     *
     * @param tables list that receives the converted metadata
     * @param packageName Java package corresponding to {@code packagePath}
     * @param packagePath file-system directory of the package
     */
    private void addClassesByFile(List<DBTable> tables, String packageName, String packagePath)
            throws ClassNotFoundException {
        File dir = new File(packagePath);
        if (!dir.exists() || !dir.isDirectory()) {
            return;
        }
        File[] dirfiles = dir.listFiles(new FileFilter() {
            public boolean accept(File file) {
                return file.isDirectory() || file.getName().endsWith(".class");
            }
        });
        // listFiles returns null when the directory cannot be read
        if (dirfiles == null) {
            return;
        }
        for (File file : dirfiles) {
            if (file.isDirectory()) {
                addClassesByFile(tables, packageName + "." + file.getName(), file.getAbsolutePath());
            } else {
                String className = file.getName().substring(0, file.getName().length() - 6);
                registerIfEntity(tables, classloader.loadClass(packageName + "." + className));
            }
        }
    }

    /**
     * Converts {@code clazz} and indexes it by class and table name if it carries {@link Table};
     * otherwise does nothing.
     */
    private void registerIfEntity(List<DBTable> tables, Class<?> clazz) {
        if (clazz.getAnnotation(Table.class) == null) {
            return;
        }
        DBTable dbTable = converTo(clazz);
        tables.add(dbTable);
        tableMap.put(clazz, dbTable);
        tableNameMap.put(dbTable.getName(), dbTable);
    }

    /**
     * Converts an entity class into a {@link DBTable} descriptor using reflection over its
     * annotations ({@code @Table}, {@code @CacheVersion}, {@code @Indexes}, and the field-level
     * {@code @DateTime}, {@code @UpdateTime}, {@code @Field}, {@code @PrimaryKey}).
     *
     * @throws IllegalArgumentException if the class is null, lacks {@code @Table} or a cluster, or
     *     has fields whose Java type does not match their annotation
     * @throws IllegalStateException if no column is produced
     * @throws RuntimeException if the class cannot be instantiated
     */
    protected DBTable converTo(Class<?> defClass) {
        if (defClass == null) {
            throw new IllegalArgumentException("被转化的Java对象不能为空");
        }
        Class<?> clazz;
        try {
            // Instantiation fails fast for classes without an accessible no-arg constructor.
            clazz = defClass.newInstance().getClass();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        // table-level annotation
        Table annoTable = clazz.getAnnotation(Table.class);
        if (annoTable == null) {
            throw new IllegalArgumentException(clazz + "无法转化为数据库,请使用@Table注解");
        }

        // table name: explicit name or the simple class name, lower-cased
        String tableName = StringUtil.isBlank(annoTable.name()) ? clazz.getSimpleName() : annoTable.name();
        DBTable table = new DBTable(tableName.toLowerCase());

        // cache version
        String cacheVersion = Const.DEFAULT_CACHE_VERSION;
        CacheVersion annoCacheVersion = clazz.getAnnotation(CacheVersion.class);
        if (annoCacheVersion != null && StringUtil.isNotBlank(annoCacheVersion.value())) {
            cacheVersion = annoCacheVersion.value();
        }
        table.setCacheVersion(cacheVersion);

        // cluster name (mandatory)
        String cluster = annoTable.cluster();
        if (StringUtil.isBlank(cluster)) {
            throw new IllegalArgumentException(clazz + " @Table的cluster不能为空");
        }
        table.setCluster(cluster);

        // sharding field, number of sharded tables, cache flag
        table.setShardingBy(annoTable.shardingBy());
        table.setShardingNum(annoTable.shardingNum());
        table.setCache(annoTable.cache());

        // indexes
        _parseDBIndex(table, clazz);

        // columns
        for (Field f : clazz.getDeclaredFields()) {
            // @DateTime: must be a java.util.Date field
            DateTime datetime = f.getAnnotation(DateTime.class);
            if (datetime != null) {
                if (f.getType() != Date.class) {
                    throw new IllegalArgumentException(clazz + " " + f.getName() + " " + f.getType() + " 无法转化为日期字段");
                }
                String fieldName = f.getName();
                if (StringUtil.isNotBlank(datetime.name())) {
                    fieldName = datetime.name();
                }
                BeansUtil.putAliasField(clazz, fieldName, f);
                DBTableColumn column = new DBTableColumn();
                column.setField(fieldName);
                column.setType(DataTypeBind.DATETIME.getDBType());
                column.setHasDefault(datetime.hasDefault());
                if (column.isHasDefault()) {
                    column.setDefaultValue(DataTypeBind.DATETIME.getDefaultValue());
                }
                column.setComment(datetime.comment());
                table.addColumn(column);
            }

            // @UpdateTime: must be a java.sql.Timestamp field; always has a default
            UpdateTime updateTime = f.getAnnotation(UpdateTime.class);
            if (updateTime != null) {
                if (f.getType() != java.sql.Timestamp.class) {
                    throw new IllegalArgumentException(clazz + " " + f.getName() + " " + f.getType() + " 无法转化为时间戳字段");
                }
                String fieldName = f.getName();
                if (StringUtil.isNotBlank(updateTime.name())) {
                    fieldName = updateTime.name();
                }
                BeansUtil._aliasFieldCache.put(clazz.getName() + fieldName, f);
                DBTableColumn column = new DBTableColumn();
                column.setField(fieldName);
                column.setType(DataTypeBind.UPDATETIME.getDBType());
                column.setHasDefault(true);
                column.setDefaultValue(DataTypeBind.UPDATETIME.getDefaultValue());
                column.setComment(updateTime.comment());
                table.addColumn(column);
            }

            // @Field: ordinary column; date/timestamp types must use their dedicated annotations
            org.pinus4j.entity.annotations.Field dbField = f.getAnnotation(org.pinus4j.entity.annotations.Field.class);
            if (dbField != null) {
                if (f.getType() == java.sql.Timestamp.class) {
                    throw new IllegalArgumentException(clazz + " " + f.getName() + "应该是时间戳类型,必须使用@UpdateTime标注");
                }
                if (f.getType() == Date.class) {
                    throw new IllegalArgumentException(clazz + " " + f.getName() + "应该是日期类型,必须使用@DateTime标注");
                }
                String fieldName = f.getName();
                if (StringUtil.isNotBlank(dbField.name())) {
                    fieldName = dbField.name();
                }
                BeansUtil._aliasFieldCache.put(clazz.getName() + fieldName, f);
                DataTypeBind fieldType = DataTypeBind.getEnum(f.getType());
                DBTableColumn column = new DBTableColumn();
                column.setField(fieldName);
                column.setType(fieldType.getDBType());
                column.setCanNull(dbField.isCanNull());
                column.setLength(_getLength(f, dbField.length()));
                column.setHasDefault(dbField.hasDefault());
                column.setComment(dbField.comment());
                if (column.isHasDefault()) {
                    column.setDefaultValue(fieldType.getDefaultValue());
                }
                // strings longer than the configured limit are stored as TEXT
                if (column.getType().equals(DataTypeBind.STRING.getDBType())
                        && column.getLength() > Const.COLUMN_TEXT_LENGTH) {
                    column.setType(DataTypeBind.TEXT.getDBType());
                    column.setHasDefault(false); // text default value gen by pinus, not db.
                    column.setLength(0);
                    column.setDefaultValue(DataTypeBind.TEXT.getDefaultValue());
                }
                // boolean columns always have length 1
                if (column.getType().equals(DataTypeBind.BOOL.getDBType())) {
                    column.setLength(1);
                }
                table.addColumn(column);
            }

            // @PrimaryKey: registered both as a primary key and as a regular column
            PrimaryKey pk = f.getAnnotation(PrimaryKey.class);
            if (pk != null) {
                String fieldName = f.getName();
                if (StringUtil.isNotBlank(pk.name())) {
                    fieldName = pk.name();
                }
                BeansUtil._aliasFieldCache.put(clazz.getName() + fieldName, f);
                DBTablePK primaryKey = new DBTablePK();
                primaryKey.setField(fieldName);
                primaryKey.setType(DataTypeBind.getEnum(f.getType()).getDBType());
                primaryKey.setLength(_getLength(f, pk.length()));
                primaryKey.setComment(pk.comment());
                primaryKey.setAutoIncrement(pk.isAutoIncrement());
                table.addPrimaryKey(primaryKey);
                table.addColumn(primaryKey.toTableColumn());
            }
        }

        // validate primary key configuration
        table.checkPrimaryKey();

        if (table.getColumns().isEmpty()) {
            throw new IllegalStateException(clazz + "被转化的java对象没有包含任何列属性" + defClass);
        }

        return table;
    }

    /**
     * Parses the class-level {@code @Indexes} annotation into {@link DBTableIndex} entries. Each
     * {@code @Index.field} is a comma-separated list of column names; blanks are removed first.
     *
     * @throws IllegalArgumentException if {@code @Indexes} is present but empty
     */
    private static void _parseDBIndex(DBTable table, Class<?> clazz) {
        Indexes annoIndexes = clazz.getAnnotation(Indexes.class);
        if (annoIndexes == null) {
            return;
        }
        Index[] indexes = annoIndexes.value();
        if (indexes == null || indexes.length <= 0) {
            throw new IllegalArgumentException("索引注解错误, " + clazz);
        }
        for (Index index : indexes) {
            DBTableIndex dbIndex = new DBTableIndex();
            List<String> indexFields = Arrays.asList(StringUtil.removeBlank(index.field()).split(","));
            dbIndex.setFields(indexFields);
            dbIndex.setUnique(index.isUnique());
            table.addIndex(dbIndex);
        }
    }

    /**
     * Returns the column length: the annotated length when positive, otherwise a per-type default
     * (STRING 255, BYTE 4, SHORT 6, INT 11, LONG 20). For any other type a non-positive annotated
     * length is returned unchanged.
     */
    private static int _getLength(Field f, int annoLength) {
        if (annoLength > 0) {
            return annoLength;
        }
        switch (DataTypeBind.getEnum(f.getType())) {
            case STRING:
                return 255;
            case BYTE:
                return 4;
            case SHORT:
                return 6;
            case INT:
                return 11;
            case LONG:
                return 20;
            default:
                return annoLength;
        }
    }

    /**
     * @throws IllegalStateException if {@code clazz} has not been loaded
     */
    @Override
    public DBTable getTableMeta(Class<?> clazz) {
        DBTable dbTable = tableMap.get(clazz);
        if (dbTable == null) {
            throw new IllegalStateException("找不到实体的元信息 class=" + clazz);
        }
        return dbTable;
    }

    /**
     * @throws IllegalStateException if no loaded table has this name
     */
    @Override
    public DBTable getTableMeta(String tableName) {
        DBTable dbTable = tableNameMap.get(tableName);
        if (dbTable == null) {
            throw new IllegalStateException("找不到实体的元信息 table name=" + tableName);
        }
        return dbTable;
    }

    /**
     * Returns the internal list of loaded tables (not a copy); it is cleared and refilled by
     * reloadEntity.
     */
    @Override
    public List<DBTable> getTableMetaList() {
        return tables;
    }
}